[mlir][tensor] InsertSliceOp verification.
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 29 Nov 2021 16:22:45 +0000 (16:22 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 30 Nov 2021 20:37:06 +0000 (20:37 +0000)
This revision reintroduces tensor.insert_slice verification which seems
to have vanished over time: a verifier was initially introduced in cf9503c1b752062d9abfb2c7922a50574d9c5de4
but for some reason the invalid.mlir was not properly updated; as time passed the verifier was not called anymore and later the code was deleted.

As a consequence, a non-negligible portion of tests has run astray using invalid
tensor.insert_slice semantics and needed to be fixed.

Also, extract isRankReducedType from TensorOps for better reuse
Originally, this facility was used by both tensor and memref forms but
it got copied around as dialects were split.

Differential Revision: https://reviews.llvm.org/D114715

18 files changed:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/include/mlir/IR/BuiltinTypes.h
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Dialect/MemRef/invalid.mlir
mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
mlir/test/Dialect/Tensor/invalid.mlir
mlir/test/Dialect/Tensor/ops.mlir
mlir/test/IR/core-ops.mlir

index 2151606f4bfb3a5f85262c13cc9c7aba2518f8f2..5442fed96dd1af7249f55892a9915776eb21059a 100644 (file)
@@ -184,11 +184,14 @@ def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides<
     ShapedType::kDynamicStrideOrOffset encodes that the corresponding entry has
     a dynamic value.
 
-    After buffer-allocation, the "extract_slice" op is expected to lower into a
-    "subview" op.
+    After buffer allocation, the "extract_slice" op is expected to lower into a
+    memref.subview op.
 
     An extract_slice operation may additionally reduce the rank of the resulting
     tensor by removing dimensions that are statically known to be of size 1.
+    This rank-reduction behavior is not required by the op semantics: this
+    flexibility allows to progressively drop unit dimensions while lowering
+    between different flavors of ops on that operate on tensors.
 
     Example:
 
@@ -196,8 +199,8 @@ def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides<
     // Rank-reducing extract_slice.
     %1 = tensor.extract_slice %0[0, 0, 0][1, 16, 4][1, 1, 1] :
       tensor<8x16x4xf32> to tensor<16x4xf32>
-    %3 = tensor.extract_slice %2[3, 4, 2][1, 6, 3][1, 1, 1] :
-      tensor<8x16x4xf32> to tensor<6x3xf32>
+    %3 = tensor.extract_slice %2[%o0, 4, %o2][1, %sz1, 1][1, %st1, 1] :
+      tensor<8x16x4xf32> to tensor<1x?xf32>
     ```
   }];
 
@@ -257,24 +260,28 @@ def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides<
     /// An extract_slice result type can be fully inferred from the source type
     /// and the static representation of offsets, sizes and strides. Special
     /// sentinels encode the dynamic case.
-    static Type inferResultType(RankedTensorType sourceRankedTensorType,
-                                ArrayRef<int64_t> staticOffsets,
-                                ArrayRef<int64_t> staticSizes,
-                                ArrayRef<int64_t> staticStrides);
-    static Type inferResultType(RankedTensorType sourceRankedTensorType,
-                                ArrayRef<OpFoldResult> staticOffsets,
-                                ArrayRef<OpFoldResult> staticSizes,
-                                ArrayRef<OpFoldResult> staticStrides);
-    static Type inferRankReducedResultType(unsigned resultRank,
-                                           RankedTensorType sourceRankedTensorType,
-                                           ArrayRef<int64_t> staticOffsets,
-                                           ArrayRef<int64_t> staticSizes,
-                                           ArrayRef<int64_t> staticStrides);
-    static Type inferRankReducedResultType(unsigned resultRank,
-                                           RankedTensorType sourceRankedTensorType,
-                                           ArrayRef<OpFoldResult> staticOffsets,
-                                           ArrayRef<OpFoldResult> staticSizes,
-                                           ArrayRef<OpFoldResult> staticStrides);
+    static RankedTensorType inferResultType(
+      RankedTensorType sourceRankedTensorType,
+      ArrayRef<int64_t> staticOffsets,
+      ArrayRef<int64_t> staticSizes,
+      ArrayRef<int64_t> staticStrides);
+    static RankedTensorType inferResultType(
+      RankedTensorType sourceRankedTensorType,
+      ArrayRef<OpFoldResult> staticOffsets,
+      ArrayRef<OpFoldResult> staticSizes,
+      ArrayRef<OpFoldResult> staticStrides);
+    static RankedTensorType inferRankReducedResultType(
+      unsigned resultRank,
+      RankedTensorType sourceRankedTensorType,
+      ArrayRef<int64_t> staticOffsets,
+      ArrayRef<int64_t> staticSizes,
+      ArrayRef<int64_t> staticStrides);
+    static RankedTensorType inferRankReducedResultType(
+      unsigned resultRank,
+      RankedTensorType sourceRankedTensorType,
+      ArrayRef<OpFoldResult> staticOffsets,
+      ArrayRef<OpFoldResult> staticSizes,
+      ArrayRef<OpFoldResult> staticStrides);
 
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
@@ -469,8 +476,27 @@ def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides<
     ShapedType::kDynamicStrideOrOffset encodes that the corresponding entry has
     a dynamic value.
 
-    After buffer-allocation, the "insert_slice" op is expected to become an
-    in-place buffer update.
+    After buffer allocation, the "insert_slice" op is expected to lower into a
+    memref.subview op.
+
+    An insert_slice operation may additionally specify insertion into a tensor
+    of higher rank than the source tensor, along dimensions that are statically
+    known to be of size 1.
+    This rank-altering behavior is not required by the op semantics: this
+    flexibility allows to progressively drop unit dimensions while lowering
+    between different flavors of ops on that operate on tensors.
+    The rank-altering behavior of tensor.insert_slice matches the rank-reducing
+    behavior of tensor.extract_slice.
+
+    Example:
+
+    ```
+    // Rank-reducing extract_slice.
+    %1 = tensor.insert_slice %t into %0[0, 0, 0][1, 16, 4][1, 1, 1] :
+      tensor<16x4xf32> into tensor<8x16x4xf32>
+    %3 = tensor.insert_slice %tt into %2[%o0, 4, %o2][1, %sz1, 1][1, %st1, 1] :
+      tensor<1x?xf32> into tensor<8x16x4xf32>
+    ```
   }];
 
   let arguments = (ins
@@ -493,8 +519,6 @@ def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides<
     attr-dict `:` type($source) `into` type($dest)
   }];
 
-  let verifier = ?;
-
   let builders = [
     // Build a InsertSliceOp with mixed static and dynamic entries.
     OpBuilder<(ins "Value":$source, "Value":$dest,
index 5838f1d1fb241916a878e1c5ad675769bc0d1fc2..bf5d0a8bd1bf8b2bd0e8f9002c43f40ed12377e2 100644 (file)
 
 namespace mlir {
 
-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// Helper function to dispatch an OpFoldResult into `staticVec` if:
+///   a) it is an IntegerAttr
+/// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
+/// In such dynamic cases, a copy of the `sentinel` value is also pushed to
 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
 /// come from an AttrSizedOperandSegments trait.
 void dispatchIndexOpFoldResult(OpFoldResult ofr,
@@ -31,11 +32,8 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
                                SmallVectorImpl<int64_t> &staticVec,
                                int64_t sentinel);
 
-/// Helper function to dispatch multiple OpFoldResults into either the
-/// `dynamicVec` (for Values) or into `staticVec` (for IntegerAttrs).
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
-/// `staticVec`. This is useful to extract mixed static and dynamic entries that
-/// come from an AttrSizedOperandSegments trait.
+/// Helper function to dispatch multiple OpFoldResults according to the behavior
+/// of `dispatchIndexOpFoldResult(OpFoldResult ofr` for a single OpFoldResult.
 void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
                                 SmallVectorImpl<Value> &dynamicVec,
                                 SmallVectorImpl<int64_t> &staticVec,
index f3d2c24073dc63a1c853f788eb24818d04e1b330..10d8a5847ebd2dac534a6c874301fefa18f75ed9 100644 (file)
@@ -369,6 +369,25 @@ llvm::Optional<llvm::SmallDenseSet<unsigned>>
 computeRankReductionMask(ArrayRef<int64_t> originalShape,
                          ArrayRef<int64_t> reducedShape);
 
+/// Enum that captures information related to verifier error conditions on
+/// slice insert/extract type of ops.
+enum class SliceVerificationResult {
+  Success,
+  RankTooLarge,
+  SizeMismatch,
+  ElemTypeMismatch,
+  // Error codes to ops with a memory space and a layout annotation.
+  MemSpaceMismatch,
+  LayoutMismatch
+};
+
+/// Check if `originalType` can be rank reduced to `candidateReducedType` type
+/// by dropping some dimensions with static size `1`.
+/// Return `SliceVerificationResult::Success` on success or an appropriate error
+/// code.
+SliceVerificationResult isRankReducedType(ShapedType originalType,
+                                          ShapedType candidateReducedType);
+
 //===----------------------------------------------------------------------===//
 // Deferred Method Definitions
 //===----------------------------------------------------------------------===//
index 77cf563abe1a11743c2a51666cf51c29715d95fd..7961638a4661bf53c690e8010bd2415c00772e78 100644 (file)
@@ -2248,8 +2248,8 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
 
     Location loc = op.getLoc();
     int axis = op.axis();
-    Value axisValue =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
+    Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
+        loc, rewriter.getIndexAttr(axis));
     int rank = resultType.getRank();
     SmallVector<Value, 3> offsets, sizes, strides;
     sizes.reserve(rank);
@@ -2257,31 +2257,41 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
     offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
     for (int i = 0; i < rank; ++i) {
-      sizes.push_back(
-          rewriter.create<tensor::DimOp>(loc, adaptor.getOperands()[0], i));
+      sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
+          loc, adaptor.getOperands()[0], i));
     }
 
     Value resultDimSize = sizes[axis];
     for (auto arg : adaptor.getOperands().drop_front()) {
-      auto size = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
-      resultDimSize = rewriter.create<arith::AddIOp>(loc, resultDimSize, size);
+      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
+      resultDimSize =
+          rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
     }
     sizes[axis] = resultDimSize;
 
     Value init = rewriter.create<linalg::InitTensorOp>(
         loc, resultType.getShape(), resultType.getElementType());
 
-    Value zeroVal = rewriter.create<arith::ConstantOp>(
+    Value zeroVal = rewriter.createOrFold<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(resultType.getElementType()));
     Value result =
         rewriter.create<linalg::FillOp>(loc, zeroVal, init).getResult(0);
 
+    auto toOpFoldResult = [](Value v) -> OpFoldResult {
+      auto op = v.getDefiningOp<arith::ConstantIndexOp>();
+      if (!op)
+        return v;
+      return op.getValue();
+    };
     for (auto arg : adaptor.getOperands()) {
-      sizes[axis] = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
-      result = rewriter.create<tensor::InsertSliceOp>(loc, arg, result, offsets,
-                                                      sizes, strides);
+      sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
+      result = rewriter.createOrFold<tensor::InsertSliceOp>(
+          loc, arg, result,
+          llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
+          llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
+          llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
       offsets[axis] =
-          rewriter.create<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
+          rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
     }
     rewriter.replaceOp(op, result);
     return success();
index 9ef930cb204c18e5e3ad4b9624666c043b0db8a8..36828eabd59f72b9831237aa59809be8dbdde1f3 100644 (file)
@@ -835,16 +835,14 @@ void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 // InitTensorOp
 //===----------------------------------------------------------------------===//
+
 void InitTensorOp::build(OpBuilder &b, OperationState &result,
                          ArrayRef<OpFoldResult> sizes, Type elementType,
                          ArrayRef<NamedAttribute> attrs) {
-  unsigned rank = sizes.size();
   SmallVector<Value, 4> dynamicSizes;
   SmallVector<int64_t, 4> staticSizes;
-  for (unsigned i = 0; i < rank; ++i) {
-    dispatchIndexOpFoldResult(sizes[i], dynamicSizes, staticSizes,
-                              ShapedType::kDynamicSize);
-  }
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
+                             ShapedType::kDynamicSize);
   auto resultType = RankedTensorType ::get(staticSizes, elementType);
   build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes));
   result.addAttributes(attrs);
@@ -1127,19 +1125,16 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
                         ArrayRef<NamedAttribute> attrs) {
   assert(resultType.isa<RankedTensorType>());
   auto sourceType = source.getType().cast<RankedTensorType>();
-  unsigned rank = sourceType.getRank();
   SmallVector<Value, 4> dynamicLow, dynamicHigh;
   SmallVector<int64_t, 4> staticLow, staticHigh;
-  for (unsigned i = 0; i < rank; ++i) {
-    // staticLow and staticHigh have full information of the padding config.
-    // This will grow staticLow and staticHigh with 1 value. If the config is
-    // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
-    // value as well.
-    dispatchIndexOpFoldResult(low[i], dynamicLow, staticLow,
-                              ShapedType::kDynamicSize);
-    dispatchIndexOpFoldResult(high[i], dynamicHigh, staticHigh,
-                              ShapedType::kDynamicSize);
-  }
+  // staticLow and staticHigh have full information of the padding config.
+  // This will grow staticLow and staticHigh with 1 value. If the config is
+  // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
+  // value as well.
+  dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
+                             ShapedType::kDynamicSize);
+  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
+                             ShapedType::kDynamicSize);
   if (!resultType) {
     resultType =
         PadTensorOp::inferResultType(sourceType, staticLow, staticHigh);
index b9bd01b439a9dc7c48009d3422e3f6df1014c25e..938197df59c220446c4bd75e5fb68b2a57a726ae 100644 (file)
@@ -504,11 +504,13 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
   return numOccurences;
 }
 
-/// Given the type of the un-rank reduced subview result type and the
-/// rank-reduced result type, computes the dropped dimensions. This accounts for
-/// cases where there are multiple unit-dims, but only a subset of those are
-/// dropped. For MemRefTypes these can be disambiguated using the strides. If a
-/// dimension is dropped the stride must be dropped too.
+/// Given the `originalType` and a `candidateReducedType` whose shape is assumed
+/// to be a subset of `originalType` with some `1` entries erased, return the
+/// set of indices that specifies which of the entries of `originalShape` are
+/// dropped to obtain `reducedShape`.
+/// This accounts for cases where there are multiple unit-dims, but only a
+/// subset of those are dropped. For MemRefTypes these can be disambiguated
+/// using the strides. If a dimension is dropped the stride must be dropped too.
 static llvm::Optional<llvm::SmallDenseSet<unsigned>>
 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                                ArrayRef<OpFoldResult> sizes) {
@@ -1548,8 +1550,7 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                              staticStrides, ShapedType::kDynamicStrideOrOffset);
   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
-                                    staticSizes, staticStrides)
-      .cast<MemRefType>();
+                                    staticSizes, staticStrides);
 }
 
 Type SubViewOp::inferRankReducedResultType(
@@ -1706,88 +1707,58 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
 /// For ViewLikeOpInterface.
 Value SubViewOp::getViewSource() { return source(); }
 
-enum SubViewVerificationResult {
-  Success,
-  RankTooLarge,
-  SizeMismatch,
-  ElemTypeMismatch,
-  MemSpaceMismatch,
-  AffineMapMismatch
-};
-
 /// Checks if `original` Type type can be rank reduced to `reduced` type.
 /// This function is slight variant of `is subsequence` algorithm where
 /// not matching dimension must be 1.
-static SubViewVerificationResult
-isRankReducedType(Type originalType, Type candidateReducedType,
-                  ArrayRef<OpFoldResult> sizes, std::string *errMsg = nullptr) {
-  if (originalType == candidateReducedType)
-    return SubViewVerificationResult::Success;
-  if (!originalType.isa<MemRefType>())
-    return SubViewVerificationResult::Success;
-  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
-    return SubViewVerificationResult::Success;
-
-  ShapedType originalShapedType = originalType.cast<ShapedType>();
-  ShapedType candidateReducedShapedType =
-      candidateReducedType.cast<ShapedType>();
-
-  // Rank and size logic is valid for all ShapedTypes.
-  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
-  ArrayRef<int64_t> candidateReducedShape =
-      candidateReducedShapedType.getShape();
-  unsigned originalRank = originalShape.size(),
-           candidateReducedRank = candidateReducedShape.size();
-  if (candidateReducedRank > originalRank)
-    return SubViewVerificationResult::RankTooLarge;
+static SliceVerificationResult
+isRankReducedMemRefType(MemRefType originalType,
+                        MemRefType candidatecandidateReducedType,
+                        ArrayRef<OpFoldResult> sizes) {
+  auto partialRes =
+      isRankReducedType(originalType, candidatecandidateReducedType);
+  if (partialRes != SliceVerificationResult::Success)
+    return partialRes;
 
   MemRefType original = originalType.cast<MemRefType>();
-  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
+  MemRefType candidateReduced =
+      candidatecandidateReducedType.cast<MemRefType>();
 
   auto optionalUnusedDimsMask =
       computeMemRefRankReductionMask(original, candidateReduced, sizes);
 
   // Sizes cannot be matched in case empty vector is returned.
   if (!optionalUnusedDimsMask.hasValue())
-    return SubViewVerificationResult::SizeMismatch;
+    return SliceVerificationResult::LayoutMismatch;
 
-  if (originalShapedType.getElementType() !=
-      candidateReducedShapedType.getElementType())
-    return SubViewVerificationResult::ElemTypeMismatch;
-
-  // Strided layout logic is relevant for MemRefType only.
   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
-    return SubViewVerificationResult::MemSpaceMismatch;
-  return SubViewVerificationResult::Success;
+    return SliceVerificationResult::MemSpaceMismatch;
+
+  return SliceVerificationResult::Success;
 }
 
 template <typename OpTy>
-static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
-                                            OpTy op, Type expectedType,
-                                            StringRef errMsg = "") {
+static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
+                                            OpTy op, Type expectedType) {
   auto memrefType = expectedType.cast<ShapedType>();
   switch (result) {
-  case SubViewVerificationResult::Success:
+  case SliceVerificationResult::Success:
     return success();
-  case SubViewVerificationResult::RankTooLarge:
+  case SliceVerificationResult::RankTooLarge:
     return op.emitError("expected result rank to be smaller or equal to ")
-           << "the source rank. " << errMsg;
-  case SubViewVerificationResult::SizeMismatch:
+           << "the source rank. ";
+  case SliceVerificationResult::SizeMismatch:
     return op.emitError("expected result type to be ")
            << expectedType
-           << " or a rank-reduced version. (mismatch of result sizes) "
-           << errMsg;
-  case SubViewVerificationResult::ElemTypeMismatch:
+           << " or a rank-reduced version. (mismatch of result sizes) ";
+  case SliceVerificationResult::ElemTypeMismatch:
     return op.emitError("expected result element type to be ")
-           << memrefType.getElementType() << errMsg;
-  case SubViewVerificationResult::MemSpaceMismatch:
-    return op.emitError("expected result and source memory spaces to match.")
-           << errMsg;
-  case SubViewVerificationResult::AffineMapMismatch:
+           << memrefType.getElementType();
+  case SliceVerificationResult::MemSpaceMismatch:
+    return op.emitError("expected result and source memory spaces to match.");
+  case SliceVerificationResult::LayoutMismatch:
     return op.emitError("expected result type to be ")
            << expectedType
-           << " or a rank-reduced version. (mismatch of result affine map) "
-           << errMsg;
+           << " or a rank-reduced version. (mismatch of result layout) ";
   }
   llvm_unreachable("unexpected subview verification result");
 }
@@ -1813,10 +1784,9 @@ static LogicalResult verify(SubViewOp op) {
       extractFromI64ArrayAttr(op.static_sizes()),
       extractFromI64ArrayAttr(op.static_strides()));
 
-  std::string errMsg;
-  auto result =
-      isRankReducedType(expectedType, subViewType, op.getMixedSizes(), &errMsg);
-  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
+  auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
+                                        subViewType, op.getMixedSizes());
+  return produceSubViewErrorMsg(result, op, expectedType);
 }
 
 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
index 0eff26bc86de47baa2019004129a3b5b5cbcec3d..7f1bd74cd37aa346318bb4a29c1bda7cf7f04436 100644 (file)
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -655,10 +656,11 @@ static LogicalResult verify(ReshapeOp op) {
 /// An extract_slice op result type can be fully inferred from the source type
 /// and the static representation of offsets, sizes and strides. Special
 /// sentinels encode the dynamic case.
-Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
-                                     ArrayRef<int64_t> leadingStaticOffsets,
-                                     ArrayRef<int64_t> leadingStaticSizes,
-                                     ArrayRef<int64_t> leadingStaticStrides) {
+RankedTensorType
+ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
+                                ArrayRef<int64_t> leadingStaticOffsets,
+                                ArrayRef<int64_t> leadingStaticSizes,
+                                ArrayRef<int64_t> leadingStaticStrides) {
   // An extract_slice op may specify only a leading subset of offset/sizes/
   // strides in which case we complete with offset=0, sizes from memref type and
   // strides=1.
@@ -673,11 +675,11 @@ Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
                                sourceRankedTensorType.getElementType());
 }
 
-Type ExtractSliceOp::inferResultType(
-    RankedTensorType sourceRankedTensorType,
-    ArrayRef<OpFoldResult> leadingStaticOffsets,
-    ArrayRef<OpFoldResult> leadingStaticSizes,
-    ArrayRef<OpFoldResult> leadingStaticStrides) {
+RankedTensorType
+ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
+                                ArrayRef<OpFoldResult> leadingStaticOffsets,
+                                ArrayRef<OpFoldResult> leadingStaticSizes,
+                                ArrayRef<OpFoldResult> leadingStaticStrides) {
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
@@ -693,7 +695,7 @@ Type ExtractSliceOp::inferResultType(
 /// An extract_slice op result type can be fully inferred from the source type
 /// and the static representation of offsets, sizes and strides. Special
 /// sentinels encode the dynamic case.
-Type ExtractSliceOp::inferRankReducedResultType(
+RankedTensorType ExtractSliceOp::inferRankReducedResultType(
     unsigned resultRank, RankedTensorType sourceRankedTensorType,
     ArrayRef<int64_t> leadingStaticOffsets,
     ArrayRef<int64_t> leadingStaticSizes,
@@ -717,7 +719,7 @@ Type ExtractSliceOp::inferRankReducedResultType(
   return inferredType;
 }
 
-Type ExtractSliceOp::inferRankReducedResultType(
+RankedTensorType ExtractSliceOp::inferRankReducedResultType(
     unsigned resultRank, RankedTensorType sourceRankedTensorType,
     ArrayRef<OpFoldResult> leadingStaticOffsets,
     ArrayRef<OpFoldResult> leadingStaticSizes,
@@ -746,10 +748,12 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
+
                              ShapedType::kDynamicStrideOrOffset);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                              ShapedType::kDynamicSize);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
+
                              ShapedType::kDynamicStrideOrOffset);
   auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
   // Structuring implementation this way avoids duplication between builders.
@@ -797,89 +801,35 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
 }
 
-enum SliceVerificationResult {
-  Success,
-  RankTooLarge,
-  SizeMismatch,
-  ElemTypeMismatch,
-};
-
-/// Checks if `original` Type type can be rank reduced to `reduced` type.
-/// This function is slight variant of `is subsequence` algorithm where
-/// not matching dimension must be 1.
-static SliceVerificationResult
-isRankReducedType(Type originalType, Type candidateReducedType,
-                  std::string *errMsg = nullptr) {
-  if (originalType == candidateReducedType)
-    return SliceVerificationResult::Success;
-  if (!originalType.isa<RankedTensorType>())
-    return SliceVerificationResult::Success;
-  if (originalType.isa<RankedTensorType>() &&
-      !candidateReducedType.isa<RankedTensorType>())
-    return SliceVerificationResult::Success;
-
-  ShapedType originalShapedType = originalType.cast<ShapedType>();
-  ShapedType candidateReducedShapedType =
-      candidateReducedType.cast<ShapedType>();
-
-  // Rank and size logic is valid for all ShapedTypes.
-  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
-  ArrayRef<int64_t> candidateReducedShape =
-      candidateReducedShapedType.getShape();
-  unsigned originalRank = originalShape.size(),
-           candidateReducedRank = candidateReducedShape.size();
-  if (candidateReducedRank > originalRank)
-    return SliceVerificationResult::RankTooLarge;
-
-  auto optionalUnusedDimsMask =
-      computeRankReductionMask(originalShape, candidateReducedShape);
-
-  // Sizes cannot be matched in case empty vector is returned.
-  if (!optionalUnusedDimsMask.hasValue())
-    return SliceVerificationResult::SizeMismatch;
-
-  if (originalShapedType.getElementType() !=
-      candidateReducedShapedType.getElementType())
-    return SliceVerificationResult::ElemTypeMismatch;
-
-  // We are done for the tensor case.
-  if (originalType.isa<RankedTensorType>())
-    return SliceVerificationResult::Success;
-
-  return SliceVerificationResult::Success;
-}
-
 template <typename OpTy>
 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
-                                          OpTy op, Type expectedType,
-                                          StringRef errMsg = "") {
+                                          OpTy op, Type expectedType) {
   auto memrefType = expectedType.cast<ShapedType>();
   switch (result) {
   case SliceVerificationResult::Success:
     return success();
   case SliceVerificationResult::RankTooLarge:
-    return op.emitError("expected result rank to be smaller or equal to ")
-           << "the source rank. " << errMsg;
+    return op.emitError("expected rank to be smaller or equal to ")
+           << "the other rank. ";
   case SliceVerificationResult::SizeMismatch:
-    return op.emitError("expected result type to be ")
-           << expectedType
-           << " or a rank-reduced version. (mismatch of result sizes) "
-           << errMsg;
+    return op.emitError("expected type to be ")
+           << expectedType << " or a rank-reduced version. (size mismatch) ";
   case SliceVerificationResult::ElemTypeMismatch:
-    return op.emitError("expected result element type to be ")
-           << memrefType.getElementType() << errMsg;
+    return op.emitError("expected element type to be ")
+           << memrefType.getElementType();
+  default:
+    llvm_unreachable("unexpected extract_slice op verification result");
   }
-  llvm_unreachable("unexpected extract_slice op verification result");
 }
 
 /// Verifier for ExtractSliceOp.
 static LogicalResult verify(ExtractSliceOp op) {
   // Verify result type against inferred type.
-  auto expectedType = ExtractSliceOp::inferResultType(
-      op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
-      extractFromI64ArrayAttr(op.static_sizes()),
-      extractFromI64ArrayAttr(op.static_strides()));
-  auto result = isRankReducedType(expectedType, op.getType());
+  auto expectedType =
+      ExtractSliceOp::inferResultType(op.getSourceType(), op.getMixedOffsets(),
+                                      op.getMixedSizes(), op.getMixedStrides());
+  auto result =
+      isRankReducedType(expectedType.cast<ShapedType>(), op.getType());
   return produceSliceErrorMsg(result, op, expectedType);
 }
 
@@ -1104,10 +1054,12 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
+
                              ShapedType::kDynamicStrideOrOffset);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                              ShapedType::kDynamicSize);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
+
                              ShapedType::kDynamicStrideOrOffset);
   build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
@@ -1128,6 +1080,19 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
 }
 
+/// Verifier for InsertSliceOp.
+static LogicalResult verify(InsertSliceOp op) {
+  // insert_slice is the inverse of extract_slice, use the same type inference.
+  auto expectedType = ExtractSliceOp::inferRankReducedResultType(
+      op.getSourceType().getRank(), op.getType(),
+      extractFromI64ArrayAttr(op.static_offsets()),
+      extractFromI64ArrayAttr(op.static_sizes()),
+      extractFromI64ArrayAttr(op.static_strides()));
+  auto result =
+      isRankReducedType(expectedType.cast<ShapedType>(), op.getSourceType());
+  return produceSliceErrorMsg(result, op, expectedType);
+}
+
 /// If we have two consecutive InsertSliceOp writing to the same slice, we
 /// can mutate the second InsertSliceOp's destination to the first one's.
 ///
@@ -1202,9 +1167,16 @@ public:
     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
 
     // Create the new op in canonical form.
-    rewriter.replaceOpWithNewOp<InsertSliceOp>(
-        insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(),
+    auto sourceType = ExtractSliceOp::inferRankReducedResultType(
+        insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
         mixedOffsets, mixedSizes, mixedStrides);
+    Value toInsert = insertSliceOp.source();
+    if (sourceType != insertSliceOp.getSourceType())
+      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
+                                                 sourceType, toInsert);
+    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+        insertSliceOp, toInsert, insertSliceOp.dest(), mixedOffsets, mixedSizes,
+        mixedStrides);
     return success();
   }
 };
index 24d8ac09deff9f6acd988edb39e30ed9032c3cf8..3e50fac6fd3a86c8dea720126cc42e0cf9d8537e 100644 (file)
 
 namespace mlir {
 
-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// Helper function to dispatch an OpFoldResult into `staticVec` if:
+///   a) it is an IntegerAttr
+/// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
+/// In such dynamic cases, a copy of the `sentinel` value is also pushed to
 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
 /// come from an AttrSizedOperandSegments trait.
 void dispatchIndexOpFoldResult(OpFoldResult ofr,
                                SmallVectorImpl<Value> &dynamicVec,
                                SmallVectorImpl<int64_t> &staticVec,
                                int64_t sentinel) {
-  if (auto v = ofr.dyn_cast<Value>()) {
-    dynamicVec.push_back(v);
-    staticVec.push_back(sentinel);
+  auto v = ofr.dyn_cast<Value>();
+  if (!v) {
+    APInt apInt = ofr.get<Attribute>().cast<IntegerAttr>().getValue();
+    staticVec.push_back(apInt.getSExtValue());
     return;
   }
-  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
-  staticVec.push_back(apInt.getSExtValue());
+  dynamicVec.push_back(v);
+  staticVec.push_back(sentinel);
 }
 
 void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
index 64dceaaa4480d3cc2bba9292cfb2486d6b6494c2..33ed6b60932d4a4768f1c2d5b2d6f2a253bb657b 100644 (file)
@@ -571,7 +571,7 @@ mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
   llvm::SmallDenseSet<unsigned> unusedDims;
   unsigned reducedIdx = 0;
   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
-    // Greedily insert `originalIdx` if no match.
+    // Greedily insert `originalIdx` if match.
     if (reducedIdx < reducedRank &&
         originalShape[originalIdx] == reducedShape[reducedIdx]) {
       reducedIdx++;
@@ -590,6 +590,39 @@ mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
   return unusedDims;
 }
 
+SliceVerificationResult
+mlir::isRankReducedType(ShapedType originalType,
+                        ShapedType candidateReducedType) {
+  if (originalType == candidateReducedType)
+    return SliceVerificationResult::Success;
+
+  ShapedType originalShapedType = originalType.cast<ShapedType>();
+  ShapedType candidateReducedShapedType =
+      candidateReducedType.cast<ShapedType>();
+
+  // Rank and size logic is valid for all ShapedTypes.
+  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
+  ArrayRef<int64_t> candidateReducedShape =
+      candidateReducedShapedType.getShape();
+  unsigned originalRank = originalShape.size(),
+           candidateReducedRank = candidateReducedShape.size();
+  if (candidateReducedRank > originalRank)
+    return SliceVerificationResult::RankTooLarge;
+
+  auto optionalUnusedDimsMask =
+      computeRankReductionMask(originalShape, candidateReducedShape);
+
+  // Sizes cannot be matched in case empty vector is returned.
+  if (!optionalUnusedDimsMask.hasValue())
+    return SliceVerificationResult::SizeMismatch;
+
+  if (originalShapedType.getElementType() !=
+      candidateReducedShapedType.getElementType())
+    return SliceVerificationResult::ElemTypeMismatch;
+
+  return SliceVerificationResult::Success;
+}
+
 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
   // Empty attribute is allowed as default memory space.
   if (!memorySpace)
index 1cf88f9bc9709ac0114c468674e5953154909c24..15409702ec197fbfe11c2d10282c785905d4787c 100644 (file)
@@ -820,38 +820,24 @@ func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
   // CHECK: [[STRIDE:%.+]]   = arith.constant 1
   // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
   // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
-  // CHECK: [[ARG0_DIM0:%.+]] = tensor.dim %arg0, [[IDX0]]
   // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
-  // CHECK: [[ARG0_DIM1:%.+]] = tensor.dim %arg0, [[IDX1]]
-  // CHECK: [[ARG1_AXIS:%.+]] = tensor.dim %arg1, [[AXIS]]
-  // CHECK: [[RESULT_AXIS:%.+]] = arith.addi [[ARG0_DIM0]], [[ARG1_AXIS]]
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [11, 1]
   // CHECK: [[CST:%.+]] = arith.constant 0.0
   // CHECK: [[FILL:%.+]] = linalg.fill([[CST]], [[INIT]])
-  // CHECK: [[ARG0_DIM0:%.+]] = tensor.dim %arg0, [[AXIS]]
-  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
-  // CHECK: [[NEW_OFFSET:%.+]] = arith.addi [[OFFSET]], [[ARG0_DIM0]]
-  // CHECK: [[ARG1_DIM0:%.+]] = tensor.dim %arg1, [[AXIS]]
-  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg1 into [[INSERT0]]{{\[}}[[NEW_OFFSET]], [[OFFSET]]] {{\[}}[[ARG1_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
+  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]][0, 0] [5, 1] [1, 1]
+  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg1 into [[INSERT0]][5, 0] [6, 1] [1, 1]
   %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>)  -> (tensor<11x1xf32>)
 
   // CHECK: [[AXIS:%.+]] = arith.constant 1
   // CHECK: [[STRIDE:%.+]]   = arith.constant 1
   // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
   // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
-  // CHECK: [[ARG0_DIM0:%.+]] = tensor.dim %arg0, [[IDX0]]
   // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
-  // CHECK: [[ARG0_DIM1:%.+]] = tensor.dim %arg0, [[IDX1]]
-  // CHECK: [[ARG1_AXIS:%.+]] = tensor.dim %arg0, [[AXIS]]
-  // CHECK: [[RESULT_AXIS:%.+]] = arith.addi [[ARG0_DIM1]], [[ARG1_AXIS]]
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2]
   // CHECK: [[CST:%.+]] = arith.constant 0.0
   // CHECK: [[FILL:%.+]] = linalg.fill([[CST]], [[INIT]])
-  // CHECK: [[ARG0_DIM1:%.+]] = tensor.dim %arg0, [[AXIS]]
-  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
-  // CHECK: [[NEW_OFFSET:%.+]] = arith.addi [[OFFSET]], [[ARG0_DIM1]]
-  // CHECK: [[ARG1_DIM1:%.+]] = tensor.dim %arg0, [[AXIS]]
-  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg0 into [[INSERT0]]{{\[}}[[OFFSET]], [[NEW_OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG1_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
+  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]][0, 0] [5, 1] [1, 1]
+  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg0 into [[INSERT0]][0, 1] [5, 1] [1, 1]
   %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>)  -> (tensor<5x2xf32>)
   return
 }
index 68fe343e412e225285357b240ba89cd6d8bcc59c..4a8d2e48162d04c0249a9a391fcb5f8822b9e427 100644 (file)
@@ -428,7 +428,9 @@ func @nested_extract_slice_and_insert(
     %A : tensor<?x?xf32>,
     %B : tensor<?x?xf32> {linalg.inplaceable = true},
     %C : tensor<?x?xf32> {linalg.inplaceable = true},
-    %idx : index)
+    %idx : index,
+    %sz1 : index,
+    %sz2 : index)
   ->  (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
 {
   %f0 = arith.constant 0.0 : f32
@@ -497,9 +499,9 @@ func @nested_extract_slice_and_insert(
   // CHECK-NEXT: tensor.insert_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
   %sC = tensor.extract_slice %C[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-  %ssC = tensor.extract_slice %sC[0, 0][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
-  %FC = linalg.fill(%f0, %ssC) : f32, tensor<4x4xf32> -> tensor<4x4xf32>
-  %rsC = tensor.insert_slice %FC into %sC[0, 0][12345, 67890][1, 1] : tensor<4x4xf32> into tensor<?x?xf32>
+  %ssC = tensor.extract_slice %sC[0, 0][%sz1, 4][1, 1] : tensor<?x?xf32> to tensor<?x4xf32>
+  %FC = linalg.fill(%f0, %ssC) : f32, tensor<?x4xf32> -> tensor<?x4xf32>
+  %rsC = tensor.insert_slice %FC into %sC[0, 0][%sz2, 4][1, 1] : tensor<?x4xf32> into tensor<?x?xf32>
   %rC = tensor.insert_slice %rsC into %C[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
 
   return %rA, %rB, %rC: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
index b0cddb26c67725cfea29c1a2e82742d64b283727..20cd8606c9391885fd740655de174a8eb8681dc8 100644 (file)
@@ -592,7 +592,7 @@ func @tiled_loop_reduction(%input_3d: tensor<16x24x32xf32>,
       linalg.yield %1 : f32
     } -> tensor<4xf32>
 
-    %sum_sub = tensor.insert_slice %acc into %o_[%j][%c4][1]
+    %sum_sub = tensor.insert_slice %acc into %o_[%j][4][1]
       : tensor<4xf32> into tensor<24xf32>
     linalg.yield %sum_sub : tensor<24xf32>
   }
index 861478d7cc171af0ef3065838533b42dac91be2d..73006149a37b5d132a4a1aa150278fac4291661a 100644 (file)
@@ -644,7 +644,7 @@ func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index)
 // -----
 
 func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
-  // expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result sizes)}}
+  // expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result layout)}}
   %0 = memref.subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32>
   return
 }
@@ -653,7 +653,7 @@ func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg
 
 func @static_stride_to_dynamic_stride(%arg0 : memref<?x?x?xf32>, %arg1 : index,
     %arg2 : index) -> memref<?x?xf32, offset:?, strides: [?, ?]> {
-  // expected-error @+1 {{expected result type to be 'memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>' or a rank-reduced version. (mismatch of result sizes)}}
+  // expected-error @+1 {{expected result type to be 'memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>' or a rank-reduced version. (mismatch of result layout)}}
   %0 = memref.subview %arg0[0, 0, 0] [1, %arg1, %arg2] [1, 1, 1] : memref<?x?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
   return %0 : memref<?x?xf32, offset: ?, strides: [?, ?]>
 }
index a07e5688ead36adbec1085575b913ab1c0f502c4..79a692f9d59b14d6ebbc0f67b8ad7eeb1ede8378 100644 (file)
@@ -250,7 +250,7 @@ func @tensor_dim_of_iter_arg(%t : tensor<?x?xf32>) -> index {
 //       CHECK:   scf.for
 //       CHECK:     tensor.dim %[[t]]
 func @tensor_dim_of_iter_arg_insertslice(%t : tensor<?x?xf32>,
-                                         %t2 : tensor<?x?xf32>) -> index {
+                                         %t2 : tensor<10x10xf32>) -> index {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
@@ -258,9 +258,9 @@ func @tensor_dim_of_iter_arg_insertslice(%t : tensor<?x?xf32>,
       -> (tensor<?x?xf32>, index) {
     %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
     %2 = tensor.insert_slice %t2 into %arg0[0, 0] [10, 10] [1, 1]
-        : tensor<?x?xf32> into tensor<?x?xf32>
+        : tensor<10x10xf32> into tensor<?x?xf32>
     %3 = tensor.insert_slice %t2 into %2[1, 1] [10, 10] [1, 1]
-        : tensor<?x?xf32> into tensor<?x?xf32>
+        : tensor<10x10xf32> into tensor<?x?xf32>
     scf.yield %3, %dim : tensor<?x?xf32>, index
   }
   return %1 : index
@@ -274,7 +274,7 @@ func @tensor_dim_of_iter_arg_insertslice(%t : tensor<?x?xf32>,
 //       CHECK:     scf.for
 //       CHECK:       tensor.dim %[[t]]
 func @tensor_dim_of_iter_arg_nested_for(%t : tensor<?x?xf32>,
-                                        %t2 : tensor<?x?xf32>) -> index {
+                                        %t2 : tensor<10x10xf32>) -> index {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
@@ -284,7 +284,7 @@ func @tensor_dim_of_iter_arg_nested_for(%t : tensor<?x?xf32>,
         -> (tensor<?x?xf32>, index) {
       %dim = tensor.dim %arg2, %c0 : tensor<?x?xf32>
       %4 = tensor.insert_slice %t2 into %arg2[0, 0] [10, 10] [1, 1]
-          : tensor<?x?xf32> into tensor<?x?xf32>
+          : tensor<10x10xf32> into tensor<?x?xf32>
       scf.yield %4, %dim : tensor<?x?xf32>, index
     }
     scf.yield %2, %3 : tensor<?x?xf32>, index
@@ -292,6 +292,7 @@ func @tensor_dim_of_iter_arg_nested_for(%t : tensor<?x?xf32>,
   return %1 : index
 }
 
+
 // -----
 
 // A test case that should not canonicalize because the loop is not shape
index 9d9da02c0220f4fc600281a1dcf337ee72ce3df8..1aa4008cf90ec8bd2030693e300cf9fff5d01955 100644 (file)
@@ -348,8 +348,10 @@ func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32x
 //   CHECK-NOT:   tensor.cast
 //       CHECK:   return %[[S]] : tensor<4x6x16x32xi8>
 func @rank_reducing_insert_slice_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
+  %c0 = arith.constant 0: index
   %cast = tensor.cast %a : tensor<16x32xi8> to tensor<?x32xi8>
-  %res = tensor.insert_slice %cast into %b[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
+  %sz = tensor.dim %cast, %c0: tensor<?x32xi8>
+  %res = tensor.insert_slice %cast into %b[0, 1, 0] [1, 1, %sz] [1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
   return %res : tensor<4x6x16x32xi8>
 }
 
@@ -408,9 +410,10 @@ func @rank_reducing_insert_slice_canonicalize(%arg0 : tensor<?x?xf32>, %arg1 : i
 }
 // CHECK-LABEL: func @rank_reducing_insert_slice_canonicalize
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xf32>
-//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]]
+//       CHECK:   %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x?xf32>
+//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
 //  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
-//  CHECK-SAME:      : tensor<?x?xf32> into tensor<?x?x?xf32>
+//  CHECK-SAME:      : tensor<4x?xf32> into tensor<?x?x?xf32>
 //       CHEKC:   return %[[RESULT]]
 
 // -----
@@ -450,7 +453,7 @@ func @insert_slice_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i
   ^bb0(%arg4: index, %arg5: index):
     tensor.yield %1 : i32
   } : tensor<?x?xi32>
-  %3 = tensor.insert_slice %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
+  %3 = tensor.insert_slice %arg0 into %2[0, %arg3] [2, %0] [1, 1] : tensor<2x?xi32> into tensor<?x?xi32>
   return %3 : tensor<?x?xi32>
 }
 // CHECK-LABEL: func @insert_slice_propagate_dest_cast
@@ -462,9 +465,6 @@ func @insert_slice_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i
 // -----
 
 func @insert_slice_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
   %c9 = arith.constant 9 : index
   %c3 = arith.constant 3 : index
   %2 = tensor.extract %arg1[] : tensor<i32>
@@ -472,7 +472,7 @@ func @insert_slice_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : ten
   ^bb0(%arg2: index, %arg3: index):
     tensor.yield %2 : i32
   } : tensor<?x?xi32>
-  %5 = tensor.insert_slice %arg0 into %4[%c0, %c1] [%c2, %c3] [1, 1] : tensor<2x3xi32> into tensor<?x?xi32>
+  %5 = tensor.insert_slice %arg0 into %4[0, 1] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<?x?xi32>
   %6 = tensor.cast %5 : tensor<?x?xi32> to tensor<3x9xi32>
   return %6 : tensor<3x9xi32>
 }
@@ -527,8 +527,9 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
 //      CHECK:    %[[r:.*]] =  tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor<?x?x?xf32>
 //      CHECK:    return %[[r]]
 func @insert_tensor_cast_on_insert_slice_src(
-  %arg0 : tensor<?x5x?xf32>,  %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
-  %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [64, 5, 64] [1, 1, 1]
+    %arg0 : tensor<?x5x?xf32>,  %arg1 : tensor<?x?x?xf32>, %sz0: index, %sz2: index) -> tensor<?x?x?xf32> {
+  %c64 = arith.constant 64: index
+  %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [%c64, 5, %c64] [1, 1, 1]
     : tensor<?x5x?xf32> into tensor<?x?x?xf32>
   return %r : tensor<?x?x?xf32>
 }
@@ -559,13 +560,3 @@ func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8x
   // CHECK: return %[[INSERT]]
   return %1 : tensor<?x?x?xf32>
 }
-
-// -----
-
-// CHECK-LABEL: func @folding_incorrect_ir_triggers_infinite_loop
-func @folding_incorrect_ir_triggers_infinite_loop(
-  %A : tensor<4x4xf32>, %C : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %rC = tensor.insert_slice %A into %C[0, 0][12345, 67890][1, 1] :
-    tensor<4x4xf32> into tensor<?x?xf32>
-  return %rC: tensor<?x?xf32>
-}
index 51b67a76d14c373dc3c9c34840bb3d019cb434ea..f3c8ba28eb51eaedbfe481095efb8f70ef22cb16 100644 (file)
@@ -149,8 +149,36 @@ func @tensor.reshape_num_elements_mismatch(
 
 // -----
 
-func @slice_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
-      // expected-error @+1 {{expected result type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (mismatch of result sizes)}}
+func @extract_slice_wrong_result_rank(%t: tensor<?xf32>, %idx : index) {
+  // expected-error @+1 {{expected rank to be smaller or equal to the other rank.}}
+  %0 = tensor.extract_slice %t[0][4][1] : tensor<?xf32> to tensor<?x?xf32>
+
+  return
+}
+
+// -----
+
+func @extract_slice_wrong_result_rank(%t: tensor<?xf32>, %idx : index) {
+  // expected-error @+1 {{expected element type to be 'f32'}}
+  %0 = tensor.extract_slice %t[0][4][1] : tensor<?xf32> to tensor<4xi8>
+
+  return
+}
+
+// -----
+
+func @extract_slice_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) {
+  // expected-error @+1 {{expected type to be 'tensor<?x4x4xf32>' or a rank-reduced version. (size mismatch)}}
+  %0 = tensor.extract_slice %t[0, 0, 0][%idx, 4, 4][1, 1, 1]
+    : tensor<8x16x4xf32> to tensor<4x4x4xf32>
+
+  return
+}
+
+// -----
+
+func @extract_slice_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
+  // expected-error @+1 {{expected type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (size mismatch)}}
   %0 = tensor.extract_slice %t[0, 2, 0][4, 4, 4][1, 1, 1]
     : tensor<8x16x4xf32> to tensor<?x4x4xf32>
 
@@ -159,10 +187,38 @@ func @slice_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
 
 // -----
 
-func @slice_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) {
-      // expected-error @+1 {{expected result type to be 'tensor<?x3x?xf32>' or a rank-reduced version. (mismatch of result sizes)}}
-  %0 = tensor.extract_slice %t[0, 0, 0][%idx, 3, %idx][1, 1, 1]
-    : tensor<8x16x4xf32> to tensor<4x4x4xf32>
+func @insert_slice_wrong_result_rank(%t1: tensor<?xf32>, %t2: tensor<?x?xf32>, %idx : index) {
+  // expected-error @+1 {{expected rank to be smaller or equal to the other rank.}}
+  %0 = tensor.insert_slice %t2 into %t1[0][4][1] : tensor<?x?xf32> into tensor<?xf32>
+
+  return
+}
+
+// -----
+
+func @insert_slice_wrong_result_rank(%t1: tensor<4xi8>, %t2: tensor<?xf32>, %idx : index) {
+  // expected-error @+1 {{expected element type to be 'f32'}}
+  %0 = tensor.insert_slice %t1 into %t2[0][4][1] : tensor<4xi8> into tensor<?xf32>
+
+  return
+}
+
+// -----
+
+func @insert_slice_wrong_static_type(%t1: tensor<4x4x4xf32>, %t2: tensor<8x16x4xf32>, %idx : index) {
+  // expected-error @+1 {{expected type to be 'tensor<?x4x4xf32>' or a rank-reduced version. (size mismatch)}}
+  %0 = tensor.insert_slice %t1 into %t2[0, 0, 0][%idx, 4, 4][1, 1, 1]
+    : tensor<4x4x4xf32> into tensor<8x16x4xf32>
+
+  return
+}
+
+// -----
+
+func @insert_slice_wrong_dynamic_type(%t1: tensor<?x4x4xf32>, %t2: tensor<8x16x4xf32>, %idx : index) {
+  // expected-error @+1 {{expected type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (size mismatch)}}
+  %0 = tensor.insert_slice %t1 into %t2[0, 2, 0][4, 4, 4][1, 1, 1]
+    : tensor<?x4x4xf32> into tensor<8x16x4xf32>
 
   return
 }
index 2353d0320b4e050a1bb8a251056ca1abf465ae7c..d8c5a415fcb8b27e7209d4c6de7338f91951f84f 100644 (file)
@@ -78,3 +78,60 @@ func @tensor_reshape(%unranked: tensor<*xf32>, %shape1: tensor<1xi32>,
                : (tensor<?x?xf32>, tensor<?xi32>) -> tensor<*xf32>
   return %new_unranked : tensor<*xf32>
 }
+
+// CHECK-LABEL: func @slice({{.*}}) {
+func @slice(%t: tensor<8x16x4xf32>, %idx : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: tensor<8x16x4xf32> to tensor<?x?x?xf32>
+  %1 = tensor.extract_slice %t[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
+    : tensor<8x16x4xf32> to tensor<?x?x?xf32>
+
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4x4xf32>
+  %2 = tensor.extract_slice %t[0, 2, 0][4, 4, 4][1, 1, 1]
+    : tensor<8x16x4xf32> to tensor<4x4x4xf32>
+
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4xf32>
+  %3 = tensor.extract_slice %t[0, 2, 0][4, 1, 4][1, 1, 1]
+    : tensor<8x16x4xf32> to tensor<4x4xf32>
+
+  return
+}
+
+// CHECK-LABEL: func @insert_slice({{.*}}) {
+func @insert_slice(
+    %t: tensor<8x16x4xf32>,
+    %td: tensor<8x?x4xf32>,
+    %t2: tensor<16x32x8xf32>,
+    %t3: tensor<4x4xf32>,
+    %idx : index,
+    %sz : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  // CHECK: tensor.insert_slice
+  // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
+  %1 = tensor.insert_slice %t into %t2[%c0, %c0, %c0][8, 16, 4][%c1, %c1, %c1]
+    : tensor<8x16x4xf32> into tensor<16x32x8xf32>
+
+  // CHECK: tensor.insert_slice
+  // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
+  %2 = tensor.insert_slice %t into %t2[%c0, %idx, %c0][8, 16, 4][%c1, 1, %c1]
+    : tensor<8x16x4xf32> into tensor<16x32x8xf32>
+
+  // CHECK: tensor.insert_slice
+  // CHECK-SAME: tensor<4x4xf32> into tensor<8x16x4xf32>
+  %3 = tensor.insert_slice %t3 into %t[0, 2, 0][4, 1, 4][1, 1, 1]
+    : tensor<4x4xf32> into tensor<8x16x4xf32>
+
+  // CHECK: tensor.insert_slice
+  // CHECK-SAME: tensor<8x?x4xf32> into tensor<8x16x4xf32>
+  %4 = tensor.insert_slice %td into %t[0, %idx, 0][8, %sz, 4][1, 1, 1]
+    : tensor<8x?x4xf32> into tensor<8x16x4xf32>
+
+  return
+}
index 029d26c0c71e6849c0df5904590ab193e0f1d872..300c2542d807851ec700209dab58d53a436febdb 100644 (file)
@@ -486,53 +486,3 @@ func @assume_alignment(%0: memref<4x4xf16>) {
   memref.assume_alignment %0, 16 : memref<4x4xf16>
   return
 }
-
-// CHECK-LABEL: func @slice({{.*}}) {
-func @slice(%t: tensor<8x16x4xf32>, %idx : index) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-
-  // CHECK: tensor.extract_slice
-  // CHECK-SAME: tensor<8x16x4xf32> to tensor<?x?x?xf32>
-  %1 = tensor.extract_slice %t[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
-    : tensor<8x16x4xf32> to tensor<?x?x?xf32>
-
-  // CHECK: tensor.extract_slice
-  // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4x4xf32>
-  %2 = tensor.extract_slice %t[0, 2, 0][4, 4, 4][1, 1, 1]
-    : tensor<8x16x4xf32> to tensor<4x4x4xf32>
-
-  // CHECK: tensor.extract_slice
-  // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4xf32>
-  %3 = tensor.extract_slice %t[0, 2, 0][4, 1, 4][1, 1, 1]
-    : tensor<8x16x4xf32> to tensor<4x4xf32>
-
-  return
-}
-
-// CHECK-LABEL: func @insert_slice({{.*}}) {
-func @insert_slice(
-    %t: tensor<8x16x4xf32>,
-    %t2: tensor<16x32x8xf32>,
-    %t3: tensor<4x4xf32>,
-    %idx : index) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-
-  // CHECK: tensor.insert_slice
-  // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
-  %1 = tensor.insert_slice %t into %t2[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
-    : tensor<8x16x4xf32> into tensor<16x32x8xf32>
-
-  // CHECK: tensor.insert_slice
-  // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
-  %2 = tensor.insert_slice %t into %t2[%c0, %idx, %c0][%idx, 4, %idx][%c1, 1, %c1]
-    : tensor<8x16x4xf32> into tensor<16x32x8xf32>
-
-  // CHECK: tensor.insert_slice
-  // CHECK-SAME: tensor<4x4xf32> into tensor<8x16x4xf32>
-  %3 = tensor.insert_slice %t3 into %t[0, 2, 0][4, 1, 4][1, 1, 1]
-    : tensor<4x4xf32> into tensor<8x16x4xf32>
-
-  return
-}