return nullptr;
// Skip cast-like operations.
- // TODO: CastOpInterface could be used if CollapseShapeOp and ExpandShapeOp
- // implement that interface
values.clear();
llvm::TypeSwitch<Operation *>(defOp)
.Case<CastOp>([&](CastOp op) { values.push_back(op.getSource()); })
if (isCastLikeInsertSliceOp(op))
values.push_back(op.getSource());
})
+ .Case<ExtractSliceOp>([&](ExtractSliceOp op) {
+ if (isCastLikeExtractSliceOp(op))
+ values.push_back(op.getSource());
+ })
.Default([](Operation *op) {});
} while (!values.empty());
return true;
}
+
+bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
+ llvm::SmallBitVector droppedDims = op.getDroppedDims();
+ int64_t resultDim = 0;
+ // Source dims and result dims (apart from dropped dims) must have the same
+ // size.
+ for (int64_t dim = 0; dim < op.getSourceType().getRank(); ++dim) {
+ if (droppedDims.test(dim)) {
+ continue;
+ }
+ FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
+ op.getSource(), op.getResult(), dim, resultDim);
+ if (failed(equalDimSize) || !*equalDimSize)
+ return false;
+ ++resultDim;
+ }
+
+ return true;
+}
{replacement_0 = 0} : tensor<?xf32> into tensor<1x?x1xf32>
return
}
+
+// -----
+
+func.func @cast_like_extract_slice() {
+ %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
+ // expected-remark @below {{replacement found}}
+ %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
+ %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 5, 1, 1][1, 1, 1, 1]
+ {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<5xf32>
+ return
+}
+
+// -----
+
+func.func @cast_like_extract_slice_dynamic() {
+ %0 = "test.foo"() {replaced} : () -> (tensor<?xf32>)
+ // expected-remark @below {{replacement found}}
+ %1 = "test.foo"() : () -> (tensor<1x?x1x1xf32>)
+ %c1 = arith.constant 1 : index
+ %dim = tensor.dim %1, %c1 : tensor<1x?x1x1xf32>
+ %2 = tensor.extract_slice %1[0, 0, 0, 0][1, %dim, 1, 1][1, 1, 1, 1]
+ {replacement_0 = 0} : tensor<1x?x1x1xf32> to tensor<?xf32>
+ return
+}
+
+// -----
+
+func.func @non_cast_like_extract_slice() {
+ // expected-error @below {{listener could not find replacement op}}
+ %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
+ %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
+ %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 3, 1, 1][1, 1, 1, 1]
+ {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<3xf32>
+ return
+}