return success();
}
+/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
+/// rank-extending tensor.insert_slice op.
+static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
+ ArrayRef<OpFoldResult> mixedSizes) {
+ llvm::SmallBitVector droppedDims(mixedSizes.size());
+ int64_t shapePos = 0;
+
+ for (const auto &size : enumerate(mixedSizes)) {
+ // Rank-reduced dims must have a static unit dimension.
+ bool isStaticUnitSize =
+ size.value().is<Attribute>() &&
+ size.value().get<Attribute>().cast<IntegerAttr>().getInt() == 1;
+
+ if (shapePos == static_cast<int64_t>(reducedShape.size())) {
+ // There are no more dims in the reduced shape. All remaining sizes must
+ // be rank-reduced dims.
+ assert(isStaticUnitSize && "expected unit dim");
+ droppedDims.set(size.index());
+ continue;
+ }
+
+ // Dim is preserved if the size is not a static 1.
+ if (!isStaticUnitSize) {
+ ++shapePos;
+ continue;
+ }
+
+ // Dim is preserved if the reduced shape dim is also 1.
+ if (reducedShape[shapePos] == 1) {
+ ++shapePos;
+ continue;
+ }
+
+ // Otherwise: Dim is dropped.
+ droppedDims.set(size.index());
+ }
+
+ assert(shapePos == static_cast<int64_t>(reducedShape.size()) &&
+ "dimension mismatch");
+ return droppedDims;
+}
+
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
}
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
- ArrayRef<int64_t> resultShape = getType().getShape();
- SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
- llvm::SmallBitVector droppedDims(mixedSizes.size());
- unsigned shapePos = 0;
- for (const auto &size : enumerate(mixedSizes)) {
- std::optional<int64_t> sizeVal = getConstantIntValue(size.value());
- // If the size is not 1, or if the current matched dimension of the result
- // is the same static shape as the size value (which is 1), then the
- // dimension is preserved.
- if (!sizeVal || *sizeVal != 1 ||
- (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
- shapePos++;
- continue;
- }
- droppedDims.set(size.index());
- }
- return droppedDims;
+ return ::getDroppedDims(getType().getShape(), getMixedSizes());
}
FailureOr<Value>
} // namespace
llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
- ArrayRef<int64_t> resultShape = getType().getShape();
- SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
- llvm::SmallBitVector droppedDims(mixedSizes.size());
- unsigned shapePos = 0;
- for (const auto &size : enumerate(mixedSizes)) {
- std::optional<int64_t> sizeVal = getConstantIntValue(size.value());
- // If the size is not 1, or if the current matched dimension of the result
- // is the same static shape as the size value (which is 1), then the
- // dimension is preserved.
- if (!sizeVal || *sizeVal != 1 ||
- (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
- shapePos++;
- continue;
- }
- droppedDims.set(size.index());
- }
- return droppedDims;
+ return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,