unsigned unsignedIndex = index.getValue().getZExtValue();
if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
- assert(sliceOp.isDynamicSize(unsignedIndex) &&
- "Expected dynamic slice size");
- return sliceOp.getDynamicSize(unsignedIndex);
+ // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
+ // `resolve-shaped-type-result-dims` pass.
+ if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
+ sliceOp.isDynamicSize(unsignedIndex)) {
+ return {sliceOp.getDynamicSize(unsignedIndex)};
+ }
}
// dim(cast) -> dim
return resultType;
}
+llvm::SmallDenseSet<unsigned> ExtractSliceOp::getDroppedDims() {
+ llvm::SmallDenseSet<unsigned> droppedDims;
+ ArrayRef<int64_t> resultShape = getType().getShape();
+ SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
+ unsigned shapePos = 0;
+ for (auto size : enumerate(mixedSizes)) {
+ Optional<int64_t> sizeVal = getConstantIntValue(size.value());
+ // If the size is not 1, or if the current matched dimension of the result
+ // is the same static shape as the size value (which is 1), then the
+ // dimension is preserved.
+ if (!sizeVal || sizeVal.getValue() != 1 ||
+ (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
+ shapePos++;
+ continue;
+ }
+ droppedDims.insert(size.index());
+ }
+ return droppedDims;
+}
+
+LogicalResult ExtractSliceOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ reifiedReturnShapes.resize(1);
+ reifiedReturnShapes[0].reserve(getType().getRank());
+ SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
+ llvm::SmallDenseSet<unsigned> droppedDims = getDroppedDims();
+ Location loc = getLoc();
+ for (auto size : enumerate(mixedSizes)) {
+ if (droppedDims.count(size.index()))
+ continue;
+ if (auto attr = size.value().dyn_cast<Attribute>()) {
+ reifiedReturnShapes[0].push_back(builder.create<ConstantIndexOp>(
+ loc, attr.cast<IntegerAttr>().getInt()));
+ continue;
+ }
+ reifiedReturnShapes[0].push_back(size.value().get<Value>());
+ }
+ return success();
+}
+
namespace {
/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
/// This essentially pushes memref_cast past its consuming slice when
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
// CHECK: return %[[D0]], %[[D1]], %[[D2]]
+
+// -----
+
+func @extract_slice(%arg0 : tensor<?x?x?xf32>, %arg1 : index, %arg2 : index,
+ %arg3 : index) -> (index, index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, %arg2, %arg3] [1, 1, 1] :
+ tensor<?x?x?xf32> to tensor<?x?x?xf32>
+ %1 = tensor.dim %0, %c0 : tensor<?x?x?xf32>
+ %2 = tensor.dim %0, %c1 : tensor<?x?x?xf32>
+ %3 = tensor.dim %0, %c2 : tensor<?x?x?xf32>
+ return %1, %2, %3 : index, index, index
+}
+// CHECK-LABEL: func @extract_slice(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]], %[[ARG2]], %[[ARG3]]
+
+// -----
+
+func @extract_slice_rank_reduced_1(%arg0 : tensor<?x?x?xf32>,
+ %arg1 : index) -> index {
+ %c0 = constant 0 : index
+ %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
+ tensor<?x?x?xf32> to tensor<?xf32>
+ %1 = tensor.dim %0, %c0 : tensor<?xf32>
+ return %1 : index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_1(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]]
+
+// -----
+
+func @extract_slice_rank_reduced_2(%arg0 : tensor<?x?x?xf32>,
+ %arg1 : index) -> index {
+ %c0 = constant 0 : index
+ %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
+ tensor<?x?x?xf32> to tensor<?x1xf32>
+ %1 = tensor.dim %0, %c0 : tensor<?x1xf32>
+ return %1 : index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_2(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]]
+
+// -----
+
+func @extract_slice_rank_reduced_3(%arg0 : tensor<?x?x?xf32>,
+ %arg1 : index) -> index {
+ %c1 = constant 1 : index
+ %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
+ tensor<?x?x?xf32> to tensor<1x?xf32>
+ %1 = tensor.dim %0, %c1 : tensor<1x?xf32>
+ return %1 : index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_3(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]]
+
+// -----
+
+func @extract_slice_rank_reduced_4(%arg0 : tensor<?x?x?xf32>,
+ %arg1 : index) -> index {
+ %c1 = constant 1 : index
+ %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
+ tensor<?x?x?xf32> to tensor<1x?x1xf32>
+ %1 = tensor.dim %0, %c1 : tensor<1x?x1xf32>
+ return %1 : index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_4(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]]
+
+// -----
+
+func @extract_slice_rank_reduced_5(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+ %arg2 : index) -> (index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, 1, %arg2] [1, 1, 1] :
+ tensor<?x?x?xf32> to tensor<?x?xf32>
+ %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
+ %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
+ return %1, %2 : index, index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_5(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]], %[[ARG2]]
+
+// -----
+
+func @extract_slice_rank_reduced_6(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+ %arg2 : index) -> (index, index) {
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+ %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, 1, %arg2] [1, 1, 1] :
+ tensor<?x?x?xf32> to tensor<?x1x?xf32>
+ %1 = tensor.dim %0, %c0 : tensor<?x1x?xf32>
+ %2 = tensor.dim %0, %c2 : tensor<?x1x?xf32>
+ return %1, %2 : index, index
+}
+// CHECK-LABEL: func @extract_slice_rank_reduced_6(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]], %[[ARG2]]