[mlir][MemRef] Deprecate unspecified trailing offset, size, and strides semantics...
authorMaheshRavishankar <ravishankarm@google.com>
Wed, 29 Dec 2021 18:48:02 +0000 (10:48 -0800)
committerMaheshRavishankar <ravishankarm@google.com>
Wed, 29 Dec 2021 19:18:29 +0000 (11:18 -0800)
The semantics of the ops that implement the
`OffsetSizeAndStrideOpInterface` is that if the number of offsets,
sizes or strides are less than the rank of the source, then some
default values are filled along the trailing dimensions (0 for offset,
source dimension of sizes, and 1 for strides). This is confusing,
especially with rank-reducing semantics. Immediate issue here is that
the methods of `OffsetSizeAndStridesOpInterface` assumes that the
number of values is same as the source rank. This cause out-of-bounds
errors.

So simplifying the specification of `OffsetSizeAndStridesOpInterface`
to make it invalid to specify number of offsets/sizes/strides not
equal to the source rank.

Differential Revision: https://reviews.llvm.org/D115677

13 files changed:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/MemRef/invalid.mlir
mlir/test/Dialect/MemRef/subview.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
mlir/test/Dialect/Tensor/invalid.mlir
mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
mlir/test/Integration/Dialect/Standard/CPU/test_subview.mlir

index aa20137..ab7e830 100644 (file)
@@ -1495,28 +1495,14 @@ Wrapper operator*(Wrapper a, int64_t b) {
 /// static representation of offsets, sizes and strides. Special sentinels
 /// encode the dynamic case.
 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
-                                ArrayRef<int64_t> leadingStaticOffsets,
-                                ArrayRef<int64_t> leadingStaticSizes,
-                                ArrayRef<int64_t> leadingStaticStrides) {
-  // A subview may specify only a leading subset of offset/sizes/strides in
-  // which case we complete with offset=0, sizes from memref type and strides=1.
+                                ArrayRef<int64_t> staticOffsets,
+                                ArrayRef<int64_t> staticSizes,
+                                ArrayRef<int64_t> staticStrides) {
   unsigned rank = sourceMemRefType.getRank();
-  assert(leadingStaticOffsets.size() <= rank &&
-         "unexpected leadingStaticOffsets overflow");
-  assert(leadingStaticSizes.size() <= rank &&
-         "unexpected leadingStaticSizes overflow");
-  assert(leadingStaticStrides.size() <= rank &&
-         "unexpected leadingStaticStrides overflow");
-  auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
-  auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
-  auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
-  unsigned numTrailingOffsets = rank - staticOffsets.size();
-  unsigned numTrailingSizes = rank - staticSizes.size();
-  unsigned numTrailingStrides = rank - staticStrides.size();
-  staticOffsets.append(numTrailingOffsets, 0);
-  llvm::append_range(staticSizes,
-                     sourceMemRefType.getShape().take_back(numTrailingSizes));
-  staticStrides.append(numTrailingStrides, 1);
+  (void)rank;
+  assert(staticOffsets.size() == rank && "unexpected staticOffsets overflow");
+  assert(staticSizes.size() == rank && "unexpected staticSizes overflow");
+  assert(staticStrides.size() == rank && "unexpected staticStrides overflow");
 
   // Extract source offset and strides.
   int64_t sourceOffset;
@@ -1553,29 +1539,28 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
 }
 
 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
-                                ArrayRef<OpFoldResult> leadingStaticOffsets,
-                                ArrayRef<OpFoldResult> leadingStaticSizes,
-                                ArrayRef<OpFoldResult> leadingStaticStrides) {
+                                ArrayRef<OpFoldResult> offsets,
+                                ArrayRef<OpFoldResult> sizes,
+                                ArrayRef<OpFoldResult> strides) {
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
-  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
-                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
-  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
+                             ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                              ShapedType::kDynamicSize);
-  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
-                             staticStrides, ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
+                             ShapedType::kDynamicStrideOrOffset);
   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                     staticSizes, staticStrides);
 }
 
-Type SubViewOp::inferRankReducedResultType(
-    unsigned resultRank, MemRefType sourceRankedTensorType,
-    ArrayRef<int64_t> leadingStaticOffsets,
-    ArrayRef<int64_t> leadingStaticSizes,
-    ArrayRef<int64_t> leadingStaticStrides) {
+Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
+                                           MemRefType sourceRankedTensorType,
+                                           ArrayRef<int64_t> offsets,
+                                           ArrayRef<int64_t> sizes,
+                                           ArrayRef<int64_t> strides) {
   auto inferredType =
-      inferResultType(sourceRankedTensorType, leadingStaticOffsets,
-                      leadingStaticSizes, leadingStaticStrides)
+      inferResultType(sourceRankedTensorType, offsets, sizes, strides)
           .cast<MemRefType>();
   assert(inferredType.getRank() >= resultRank && "expected ");
   int rankDiff = inferredType.getRank() - resultRank;
@@ -1598,19 +1583,19 @@ Type SubViewOp::inferRankReducedResultType(
   return inferredType;
 }
 
-Type SubViewOp::inferRankReducedResultType(
-    unsigned resultRank, MemRefType sourceRankedTensorType,
-    ArrayRef<OpFoldResult> leadingStaticOffsets,
-    ArrayRef<OpFoldResult> leadingStaticSizes,
-    ArrayRef<OpFoldResult> leadingStaticStrides) {
+Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
+                                           MemRefType sourceRankedTensorType,
+                                           ArrayRef<OpFoldResult> offsets,
+                                           ArrayRef<OpFoldResult> sizes,
+                                           ArrayRef<OpFoldResult> strides) {
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
-  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
-                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
-  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
+                             ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                              ShapedType::kDynamicSize);
-  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
-                             staticStrides, ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
+                             ShapedType::kDynamicStrideOrOffset);
   return SubViewOp::inferRankReducedResultType(
       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
       staticStrides);
@@ -1893,6 +1878,43 @@ static MemRefType getCanonicalSubViewResultType(
                                        mixedStrides);
 }
 
+/// Helper method to check if a `subview` operation is trivially a no-op. This
+/// is the case if the all offsets are zero, all strides are 1, and the source
+/// shape is same as the size of the subview. In such cases, the subview can be
+/// folded into its source.
+static bool isTrivialSubViewOp(SubViewOp subViewOp) {
+  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
+    return false;
+
+  auto mixedOffsets = subViewOp.getMixedOffsets();
+  auto mixedSizes = subViewOp.getMixedSizes();
+  auto mixedStrides = subViewOp.getMixedStrides();
+
+  // Check offsets are zero.
+  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
+        Optional<int64_t> intValue = getConstantIntValue(ofr);
+        return !intValue || intValue.getValue() != 0;
+      }))
+    return false;
+
+  // Check strides are one.
+  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
+        Optional<int64_t> intValue = getConstantIntValue(ofr);
+        return !intValue || intValue.getValue() != 1;
+      }))
+    return false;
+
+  // Check all size values are static and matches the (static) source shape.
+  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
+  for (auto size : llvm::enumerate(mixedSizes)) {
+    Optional<int64_t> intValue = getConstantIntValue(size.value());
+    if (!intValue || intValue.getValue() != sourceShape[size.index()])
+      return false;
+  }
+  // All conditions met. The `SubViewOp` is foldable as a no-op.
+  return true;
+}
+
 namespace {
 /// Pattern to rewrite a subview op with MemRefCast arguments.
 /// This essentially pushes memref.cast past its consuming subview when
@@ -1950,6 +1972,26 @@ public:
     return success();
   }
 };
+
+/// Canonicalize subview ops that are no-ops. When the source shape is not same
+/// as a result shape due to use of `affine_map`.
+class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
+public:
+  using OpRewritePattern<SubViewOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubViewOp subViewOp,
+                                PatternRewriter &rewriter) const override {
+    if (!isTrivialSubViewOp(subViewOp))
+      return failure();
+    if (subViewOp.getSourceType() == subViewOp.getType()) {
+      rewriter.replaceOp(subViewOp, subViewOp.source());
+      return success();
+    }
+    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.source(),
+                                        subViewOp.getType());
+    return success();
+  }
+};
 } // namespace
 
 /// Return the canonical type of the result of a subview.
@@ -1975,7 +2017,7 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results
       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
-           SubViewOpMemRefCastFolder>(context);
+           SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
 }
 
 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
index cec8b2c..f766513 100644 (file)
@@ -827,38 +827,31 @@ OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
 /// An extract_slice op result type can be fully inferred from the source type
 /// and the static representation of offsets, sizes and strides. Special
 /// sentinels encode the dynamic case.
-RankedTensorType
-ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
-                                ArrayRef<int64_t> leadingStaticOffsets,
-                                ArrayRef<int64_t> leadingStaticSizes,
-                                ArrayRef<int64_t> leadingStaticStrides) {
+RankedTensorType ExtractSliceOp::inferResultType(
+    RankedTensorType sourceRankedTensorType, ArrayRef<int64_t> staticOffsets,
+    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
   // An extract_slice op may specify only a leading subset of offset/sizes/
   // strides in which case we complete with offset=0, sizes from memref type and
   // strides=1.
   unsigned rank = sourceRankedTensorType.getRank();
-  assert(leadingStaticSizes.size() <= rank &&
-         "unexpected leadingStaticSizes overflow");
-  auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
-  unsigned numTrailingSizes = rank - staticSizes.size();
-  llvm::append_range(staticSizes, sourceRankedTensorType.getShape().take_back(
-                                      numTrailingSizes));
+  (void)rank;
+  assert(staticSizes.size() == rank &&
+         "unexpected staticSizes not equal to rank of source");
   return RankedTensorType::get(staticSizes,
                                sourceRankedTensorType.getElementType());
 }
 
-RankedTensorType
-ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
-                                ArrayRef<OpFoldResult> leadingStaticOffsets,
-                                ArrayRef<OpFoldResult> leadingStaticSizes,
-                                ArrayRef<OpFoldResult> leadingStaticStrides) {
+RankedTensorType ExtractSliceOp::inferResultType(
+    RankedTensorType sourceRankedTensorType, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
-  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
-                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
-  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
+                             ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                              ShapedType::kDynamicSize);
-  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
-                             staticStrides, ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
+                             ShapedType::kDynamicStrideOrOffset);
   return ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
                                          staticSizes, staticStrides);
 }
@@ -868,12 +861,10 @@ ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
 /// sentinels encode the dynamic case.
 RankedTensorType ExtractSliceOp::inferRankReducedResultType(
     unsigned resultRank, RankedTensorType sourceRankedTensorType,
-    ArrayRef<int64_t> leadingStaticOffsets,
-    ArrayRef<int64_t> leadingStaticSizes,
-    ArrayRef<int64_t> leadingStaticStrides) {
+    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
+    ArrayRef<int64_t> strides) {
   auto inferredType =
-      inferResultType(sourceRankedTensorType, leadingStaticOffsets,
-                      leadingStaticSizes, leadingStaticStrides)
+      inferResultType(sourceRankedTensorType, offsets, sizes, strides)
           .cast<RankedTensorType>();
   int rankDiff = inferredType.getRank() - resultRank;
   if (rankDiff > 0) {
@@ -892,17 +883,16 @@ RankedTensorType ExtractSliceOp::inferRankReducedResultType(
 
 RankedTensorType ExtractSliceOp::inferRankReducedResultType(
     unsigned resultRank, RankedTensorType sourceRankedTensorType,
-    ArrayRef<OpFoldResult> leadingStaticOffsets,
-    ArrayRef<OpFoldResult> leadingStaticSizes,
-    ArrayRef<OpFoldResult> leadingStaticStrides) {
+    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+    ArrayRef<OpFoldResult> strides) {
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
-  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
-                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
-  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
+                             ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                              ShapedType::kDynamicSize);
-  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
-                             staticStrides, ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
+                             ShapedType::kDynamicStrideOrOffset);
   return ExtractSliceOp::inferRankReducedResultType(
       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
       staticStrides);
@@ -919,12 +909,10 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
-
                              ShapedType::kDynamicStrideOrOffset);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                              ShapedType::kDynamicSize);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
-
                              ShapedType::kDynamicStrideOrOffset);
   auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
   // Structuring implementation this way avoids duplication between builders.
@@ -1225,12 +1213,10 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
-
                              ShapedType::kDynamicStrideOrOffset);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                              ShapedType::kDynamicSize);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
-
                              ShapedType::kDynamicStrideOrOffset);
   build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
index 9b1ae7a..8b4cd7e 100644 (file)
@@ -212,10 +212,11 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
 }
 
 /// Drops unit dimensions from the input MemRefType.
-static MemRefType dropUnitDims(MemRefType inputType) {
-  ArrayRef<int64_t> none{};
+static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
+                               ArrayRef<int64_t> sizes,
+                               ArrayRef<int64_t> strides) {
   Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
-      0, inputType, none, none, none);
+      0, inputType, offsets, sizes, strides);
   return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
 }
 
@@ -226,15 +227,16 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                  Value input) {
   MemRefType inputType = input.getType().cast<MemRefType>();
   assert(inputType.hasStaticShape());
-  MemRefType resultType = dropUnitDims(inputType);
+  SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
+  SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
+  ArrayRef<int64_t> subViewSizes = inputType.getShape();
+  MemRefType resultType =
+      dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
   if (canonicalizeStridedLayout(resultType) ==
       canonicalizeStridedLayout(inputType))
     return input;
-  SmallVector<int64_t> subviewOffsets(inputType.getRank(), 0);
-  SmallVector<int64_t> subviewStrides(inputType.getRank(), 1);
   return rewriter.create<memref::SubViewOp>(
-      loc, resultType, input, subviewOffsets, inputType.getShape(),
-      subviewStrides);
+      loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
 }
 
 /// Returns the number of dims that aren't unit dims.
index 4a963a1..6394895 100644 (file)
@@ -18,12 +18,12 @@ using namespace mlir;
 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
 
 LogicalResult mlir::verifyListOfOperandsOrIntegers(
-    Operation *op, StringRef name, unsigned maxNumElements, ArrayAttr attr,
+    Operation *op, StringRef name, unsigned numElements, ArrayAttr attr,
     ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {
   /// Check static and dynamic offsets/sizes/strides does not overflow type.
-  if (attr.size() > maxNumElements)
-    return op->emitError("expected <= ")
-           << maxNumElements << " " << name << " values";
+  if (attr.size() != numElements)
+    return op->emitError("expected ")
+           << numElements << " " << name << " values";
   unsigned expectedNumDynamicEntries =
       llvm::count_if(attr.getValue(), [&](Attribute attr) {
         return isDynamic(attr.cast<IntegerAttr>().getInt());
index 009106f..5682c85 100644 (file)
@@ -448,7 +448,7 @@ func @subview_leading_operands(%0 : memref<5x3xf32>, %1: memref<5x?xf32>) {
   // CHECK: %[[C3_3:.*]] = llvm.mlir.constant(3 : i64) : i64
   // CHECK: llvm.insertvalue %[[C3_2]], %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: llvm.insertvalue %[[C3_3]], %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-  %2 = memref.subview %0[2][3][1]: memref<5x3xf32> to memref<3x3xf32, offset: 6, strides: [3, 1]>
+  %2 = memref.subview %0[2, 0][3, 3][1, 1]: memref<5x3xf32> to memref<3x3xf32, offset: 6, strides: [3, 1]>
 
   return
 }
@@ -466,13 +466,15 @@ func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) {
   // CHECK: %[[ST0:.*]] = llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[ST1:.*]] = llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
   // Compute and insert offset from 2 + dynamic value.
-  // CHECK: %[[OFF:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[OFF0:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
-  // CHECK: %[[MUL:.*]] = llvm.mul %[[C2]], %[[ST0]] : i64
-  // CHECK: %[[NEW_OFF:.*]] = llvm.add %[[OFF]], %[[MUL]]  : i64
+  // CHECK: %[[MUL0:.*]] = llvm.mul %[[C2]], %[[ST0]] : i64
+  // CHECK: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[MUL0]] : i64
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[MUL1:.*]] = llvm.mul %[[C0]], %[[ST1]] : i64
+  // CHECK: %[[NEW_OFF:.*]] = llvm.add %[[OFF1]], %[[MUL1]]  : i64
   // CHECK: llvm.insertvalue %[[NEW_OFF]], %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
   // Sizes and strides @rank 1: static stride 1, dynamic size unchanged from source memref.
-  // CHECK: %[[SZ1:.*]] = llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
   // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: llvm.insertvalue %[[C1]], %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
@@ -482,7 +484,9 @@ func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) {
   // CHECK: %[[MUL:.*]] = llvm.mul %[[C1_2]], %[[ST0]]  : i64
   // CHECK: llvm.insertvalue %[[C3]], %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: llvm.insertvalue %[[MUL]], %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-  %1 = memref.subview %0[2][3][1]: memref<5x?xf32> to memref<3x?xf32, offset: ?, strides: [?, 1]>
+  %c0 = arith.constant 1 : index
+  %d0 = memref.dim %0, %c0 : memref<5x?xf32>
+  %1 = memref.subview %0[2, 0][3, %d0][1, 1]: memref<5x?xf32> to memref<3x?xf32, offset: ?, strides: [?, 1]>
   return
 }
 
@@ -506,7 +510,7 @@ func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) {
   // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
   // CHECK: llvm.insertvalue %[[C3]], %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: llvm.insertvalue %[[C1]], %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-  %1 = memref.subview %0[1][1][1]: memref<5x3xf32> to memref<3xf32, offset: 3, strides: [1]>
+  %1 = memref.subview %0[1, 0][1, 3][1, 1]: memref<5x3xf32> to memref<3xf32, offset: 3, strides: [1]>
 
   return
 }
index ab0be6b..3b3e64d 100644 (file)
@@ -17,17 +17,17 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
 //      CHECK-1D: vector.transfer_write {{.*}} : vector<8x12xf32>, memref<8x12xf32>
 //
 //      CHECK-1D: vector.transfer_read {{.*}} : memref<8x16xf32, #{{.*}}>, vector<8x16xf32>
-//      CHECK-1D: vector.transfer_write {{.*}} : vector<8x16xf32>, memref<8x16xf32, #{{.*}}>
+//      CHECK-1D: vector.transfer_write {{.*}} : vector<8x16xf32>, memref<8x16xf32>
 //      CHECK-1D: vector.transfer_read {{.*}} : memref<16x12xf32, #{{.*}}>, vector<16x12xf32>
-//      CHECK-1D: vector.transfer_write {{.*}} : vector<16x12xf32>, memref<16x12xf32, #{{.*}}>
+//      CHECK-1D: vector.transfer_write {{.*}} : vector<16x12xf32>, memref<16x12xf32>
 //      CHECK-1D: vector.transfer_read {{.*}} : memref<8x12xf32, #{{.*}}>, vector<8x12xf32>
-//      CHECK-1D: vector.transfer_write {{.*}} : vector<8x12xf32>, memref<8x12xf32, #{{.*}}>
+//      CHECK-1D: vector.transfer_write {{.*}} : vector<8x12xf32>, memref<8x12xf32>
 //
 //      CHECK-1D: vector.contract
 // CHECK-1D-SAME:   iterator_types = ["parallel", "parallel", "reduction"]
 // CHECK-1D-SAME:   : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
 //
-//      CHECK-1D: vector.transfer_read {{.*}} : memref<8x12xf32, #{{.*}}>, vector<8x12xf32>
+//      CHECK-1D: vector.transfer_read {{.*}} : memref<8x12xf32>, vector<8x12xf32>
 //      CHECK-1D: vector.transfer_write {{.*}} : vector<8x12xf32>, memref<8x12xf32, #{{.*}}>
 
 // CHECK-2D-LABEL:func @matmul
index 80282c2..39f9847 100644 (file)
@@ -2,13 +2,13 @@
 
 // CHECK-LABEL: func @subview_of_size_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
-//       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+//       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
 //       CHECK:   %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
 //       CHECK:   return %[[M]] : memref<16x32xi8, #{{.*}}>
 func @subview_of_size_memcast(%arg : memref<4x6x16x32xi8>) ->
   memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
   %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
-  %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
+  %1 = memref.subview %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] :
     memref<?x?x16x32xi8> to
     memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
   return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
@@ -450,3 +450,52 @@ func @fold_rank_memref(%arg0 : memref<?x?xf32>) -> (index) {
   // CHECK-NEXT: return [[C2]]
   return %rank_0 : index
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0 * 42 + d1)>
+func @fold_no_op_subview(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, #map> {
+  %0 = memref.subview %arg0[0, 0] [20, 42] [1, 1] : memref<20x42xf32> to memref<20x42xf32, #map>
+  return %0 : memref<20x42xf32, #map>
+}
+// CHECK-LABEL: func @fold_no_op_subview(
+//       CHECK:   %[[ARG0:.+]]: memref<20x42xf32>)
+//       CHECK:   %[[CAST:.+]] = memref.cast %[[ARG0]]
+//       CHECK:   return %[[CAST]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0 * 42 + d1 + 1)>
+func @no_fold_subview_with_non_zero_offset(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, #map> {
+  %0 = memref.subview %arg0[0, 1] [20, 42] [1, 1] : memref<20x42xf32> to memref<20x42xf32, #map>
+  return %0 : memref<20x42xf32, #map>
+}
+// CHECK-LABEL: func @no_fold_subview_with_non_zero_offset(
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview
+//       CHECK:    return %[[SUBVIEW]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0 * 42 + d1 * 2)>
+func @no_fold_subview_with_non_unit_stride(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, #map> {
+  %0 = memref.subview %arg0[0, 0] [20, 42] [1, 2] : memref<20x42xf32> to memref<20x42xf32, #map>
+  return %0 : memref<20x42xf32, #map>
+}
+// CHECK-LABEL: func @no_fold_subview_with_non_unit_stride(
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview
+//       CHECK:    return %[[SUBVIEW]]
+
+// -----
+
+#map = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + d1 + s0)>
+func @no_fold_dynamic_no_op_subview(%arg0 : memref<?x?xf32>) -> memref<?x?xf32, #map> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = memref.dim %arg0, %c0 : memref<?x?xf32>
+  %1 = memref.dim %arg0, %c1 : memref<?x?xf32>
+  %2 = memref.subview %arg0[0, 0] [%0, %1] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map>
+  return %2 : memref<?x?xf32, #map>
+}
+// CHECK-LABEL: func @no_fold_dynamic_no_op_subview(
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview
+//       CHECK:    return %[[SUBVIEW]]
index 97d9db8..5cf3270 100644 (file)
@@ -149,7 +149,7 @@ func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off
 // -----
 
 func @memref_reinterpret_cast_too_many_offsets(%in: memref<?xf32>) {
-  // expected-error @+1 {{expected <= 1 offset values}}
+  // expected-error @+1 {{expected 1 offset values}}
   %out = memref.reinterpret_cast %in to
            offset: [0, 0], sizes: [10, 10], strides: [10, 1]
            : memref<?xf32> to memref<10x10xf32, offset: 0, strides: [10, 1]>
@@ -580,7 +580,7 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
 
 func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = memref.alloc() : memref<8x16x4xf32>
-  // expected-error@+1 {{expected <= 3 offset values}}
+  // expected-error@+1 {{expected 3 offset values}}
   %1 = memref.subview %0[%arg0, %arg1, 0, 0][%arg2, 0, 0, 0][1, 1, 1, 1]
     : memref<8x16x4xf32> to
       memref<8x?x4xf32, offset: 0, strides:[?, ?, 4]>
@@ -840,3 +840,11 @@ func @rank(%0: f32) {
   "memref.rank"(%0): (f32)->index
   return
 }
+
+// -----
+
+#map = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (s0 + d0 * s1 + d1 * s2 + d2 * s3)>
+func @illegal_num_offsets(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) {
+  // expected-error@+1 {{expected 3 offset values}}
+  %0 = memref.subview %arg0[0, 0] [%arg1, %arg2] [1, 1] : memref<?x?x?xf32> to memref<?x?x?xf32, #map>
+}
index dbfd132..3bfc62d 100644 (file)
@@ -109,12 +109,12 @@ func @memref_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
 
   /// Subview with only leading operands.
   %24 = memref.alloc() : memref<5x3xf32>
-  // CHECK: memref.subview %{{.*}}[2] [3] [1] : memref<5x3xf32> to memref<3x3xf32, #[[$SUBVIEW_MAP9]]>
-  %25 = memref.subview %24[2][3][1]: memref<5x3xf32> to memref<3x3xf32, offset: 6, strides: [3, 1]>
+  // CHECK: memref.subview %{{.*}}[2, 0] [3, 3] [1, 1] : memref<5x3xf32> to memref<3x3xf32, #[[$SUBVIEW_MAP9]]>
+  %25 = memref.subview %24[2, 0][3, 3][1, 1]: memref<5x3xf32> to memref<3x3xf32, offset: 6, strides: [3, 1]>
 
   /// Rank-reducing subview with only leading operands.
-  // CHECK: memref.subview %{{.*}}[1] [1] [1] : memref<5x3xf32> to memref<3xf32, #[[$SUBVIEW_MAP10]]>
-  %26 = memref.subview %24[1][1][1]: memref<5x3xf32> to memref<3xf32, offset: 3, strides: [1]>
+  // CHECK: memref.subview %{{.*}}[1, 0] [1, 3] [1, 1] : memref<5x3xf32> to memref<3xf32, #[[$SUBVIEW_MAP10]]>
+  %26 = memref.subview %24[1, 0][1, 3][1, 1]: memref<5x3xf32> to memref<3xf32, offset: 3, strides: [1]>
 
   // Corner-case of 0-D rank-reducing subview with an offset.
   // CHECK: memref.subview %{{.*}}[1, 1] [1, 1] [1, 1] : memref<5x3xf32> to memref<f32, #[[$SUBVIEW_MAP11]]>
index 2c18fe4..82f880d 100644 (file)
@@ -395,13 +395,13 @@ func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6x16x3
 
 // CHECK-LABEL: func @rank_reducing_tensor_of_cast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
-//       CHECK:   %[[S:.+]] = tensor.extract_slice %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
+//       CHECK:   %[[S:.+]] = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
 // Tensor cast is moved after slice and then gets canonicalized away.
 //   CHECK-NOT:   tensor.cast
 //       CHECK:   return %[[S]] : tensor<16x32xi8>
 func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32xi8> {
   %0 = tensor.cast %arg : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
-  %1 = tensor.extract_slice %0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>
+  %1 = tensor.extract_slice %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>
   return %1 : tensor<16x32xi8>
 }
 
@@ -410,7 +410,7 @@ func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32x
 // CHECK-LABEL: func @rank_reducing_insert_slice_of_cast
 //  CHECK-SAME:   %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
 //  CHECK-SAME:   %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
-//       CHECK:   %[[S:.+]] = tensor.insert_slice %[[A]] into %[[B]][0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<16x32xi8> into tensor<4x6x16x32xi8>
+//       CHECK:   %[[S:.+]] = tensor.insert_slice %[[A]] into %[[B]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<16x32xi8> into tensor<4x6x16x32xi8>
 // Tensor cast is folded away.
 //   CHECK-NOT:   tensor.cast
 //       CHECK:   return %[[S]] : tensor<4x6x16x32xi8>
@@ -418,7 +418,7 @@ func @rank_reducing_insert_slice_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x
   %c0 = arith.constant 0: index
   %cast = tensor.cast %a : tensor<16x32xi8> to tensor<?x32xi8>
   %sz = tensor.dim %cast, %c0: tensor<?x32xi8>
-  %res = tensor.insert_slice %cast into %b[0, 1, 0] [1, 1, %sz] [1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
+  %res = tensor.insert_slice %cast into %b[0, 1, 0, 0] [1, 1, %sz, 32] [1, 1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
   return %res : tensor<4x6x16x32xi8>
 }
 
index ece2f54..8cdab35 100644 (file)
@@ -300,3 +300,20 @@ func @rank(%0: f32) {
   "tensor.rank"(%0): (f32)->index
   return
 }
+
+// -----
+
+func @illegal_num_offsets(%arg0 : tensor<?x?x?xf32>, %arg1 : index, %arg2 : index) {
+  // expected-error@+1 {{expected 3 offset values}}
+  %0 = tensor.extract_slice %arg0[0, 0] [%arg1, %arg2] [1, 1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+  return
+}
+
+// -----
+
+func @illegal_num_offsets(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?x?xf32>,
+    %arg2 : index, %arg3 : index) {
+  // expected-error@+1 {{expected 3 offset values}}
+  %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return
+}
index a3d34a6..6079b1a 100644 (file)
@@ -1,7 +1,5 @@
 // RUN: mlir-opt %s -test-vector-transfer-drop-unit-dims-patterns -split-input-file | FileCheck %s
 
-// -----
-
 func @transfer_read_rank_reducing(
       %arg : memref<1x1x3x2xi8, offset:?, strides:[6, 6, 2, 1]>) -> vector<3x2xi8> {
     %c0 = arith.constant 0 : index
index 477576c..5fb24c6 100644 (file)
@@ -13,7 +13,7 @@ func @main() {
   %0 = memref.get_global @__constant_5x3xf32 : memref<5x3xf32>
 
   /// Subview with only leading operands.
-  %1 = memref.subview %0[2][3][1]: memref<5x3xf32> to memref<3x3xf32, offset: 6, strides: [3, 1]>
+  %1 = memref.subview %0[2, 0][3, 3][1, 1]: memref<5x3xf32> to memref<3x3xf32, offset: 6, strides: [3, 1]>
   %unranked = memref.cast %1 : memref<3x3xf32, offset: 6, strides: [3, 1]> to memref<*xf32>
   call @print_memref_f32(%unranked) : (memref<*xf32>) -> ()
 
@@ -50,7 +50,7 @@ func @main() {
   // CHECK-NEXT: [2,  5,  8,  11,  14]
 
   /// Rank-reducing subview with only leading operands.
-  %4 = memref.subview %0[1][1][1]: memref<5x3xf32> to memref<3xf32, offset: 3, strides: [1]>
+  %4 = memref.subview %0[1, 0][1, 3][1, 1]: memref<5x3xf32> to memref<3xf32, offset: 3, strides: [1]>
   %unranked4 = memref.cast %4 : memref<3xf32, offset: 3, strides: [1]> to memref<*xf32>
   call @print_memref_f32(%unranked4) : (memref<*xf32>) -> ()
   //      CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}