From 34cf67aef5a3655b57e52842a1bb4913295076e4 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 1 Jun 2023 09:00:08 +0200 Subject: [PATCH] [mlir][tensor] TrackingListener: Find replacement ops through cast-like ExtractSliceOps Certain ExtractSliceOps, that do extract all elements from the destination, are treated like casts when looking for replacement ops. Such ExtractSliceOps are typically rank expansions. Differential Revision: https://reviews.llvm.org/D151804 --- mlir/include/mlir/Dialect/Tensor/Utils/Utils.h | 4 +++ .../Tensor/TransformOps/TensorTransformOps.cpp | 6 ++-- mlir/lib/Dialect/Tensor/Utils/Utils.cpp | 19 ++++++++++++ mlir/test/Dialect/Tensor/tracking-listener.mlir | 35 ++++++++++++++++++++++ 4 files changed, 62 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h index a037d40..c610b5d 100644 --- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h @@ -47,6 +47,10 @@ computeTransposedType(RankedTensorType rankedTensorType, /// the same shape. bool isCastLikeInsertSliceOp(InsertSliceOp op); +/// A tensor.extract_slice is a cast-like operation if it merely rank-reduces +/// the source tensor or extracts the entire source tensor. +bool isCastLikeExtractSliceOp(ExtractSliceOp op); + } // namespace tensor } // namespace mlir diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp index 9b609a2..09a6b50 100644 --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -38,8 +38,6 @@ tensor::TrackingListener::findReplacementOp(Operation *op, return nullptr; // Skip cast-like operations. - // TODO: CastOpInterface could be used if CollapseShapeOp and ExpandShapeOp - // implement that interface values.clear(); llvm::TypeSwitch(defOp) .Case([&](CastOp op) { values.push_back(op.getSource()); }) @@ -53,6 +51,10 @@ tensor::TrackingListener::findReplacementOp(Operation *op, if (isCastLikeInsertSliceOp(op)) values.push_back(op.getSource()); }) + .Case([&](ExtractSliceOp op) { + if (isCastLikeExtractSliceOp(op)) + values.push_back(op.getSource()); + }) .Default([](Operation *op) {}); } while (!values.empty()); diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp index 165cf9b..4d5404a 100644 --- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp @@ -123,3 +123,22 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) { return true; } + +bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) { + llvm::SmallBitVector droppedDims = op.getDroppedDims(); + int64_t resultDim = 0; + // Source dims and result dims (apart from dropped dims) must have the same + // size. + for (int64_t dim = 0; dim < op.getSourceType().getRank(); ++dim) { + if (droppedDims.test(dim)) { + continue; + } + FailureOr equalDimSize = ValueBoundsConstraintSet::areEqual( + op.getSource(), op.getResult(), dim, resultDim); + if (failed(equalDimSize) || !*equalDimSize) + return false; + ++resultDim; + } + + return true; +} diff --git a/mlir/test/Dialect/Tensor/tracking-listener.mlir b/mlir/test/Dialect/Tensor/tracking-listener.mlir index 369dcec..6341b7a 100644 --- a/mlir/test/Dialect/Tensor/tracking-listener.mlir +++ b/mlir/test/Dialect/Tensor/tracking-listener.mlir @@ -105,3 +105,38 @@ func.func @cast_like_insert_slice_dynamic( {replacement_0 = 0} : tensor into tensor<1x?x1xf32> return } + +// ----- + +func.func @cast_like_extract_slice() { + %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>) + // expected-remark @below {{replacement found}} + %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>) + %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 5, 1, 1][1, 1, 1, 1] + {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<5xf32> + return +} + +// ----- + +func.func @cast_like_extract_slice_dynamic() { + %0 = "test.foo"() {replaced} : () -> (tensor) + // expected-remark @below {{replacement found}} + %1 = "test.foo"() : () -> (tensor<1x?x1x1xf32>) + %c1 = arith.constant 1 : index + %dim = tensor.dim %1, %c1 : tensor<1x?x1x1xf32> + %2 = tensor.extract_slice %1[0, 0, 0, 0][1, %dim, 1, 1][1, 1, 1, 1] + {replacement_0 = 0} : tensor<1x?x1x1xf32> to tensor + return +} + +// ----- + +func.func @non_cast_like_extract_slice() { + // expected-error @below {{listener could not find replacement op}} + %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>) + %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>) + %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 3, 1, 1][1, 1, 1, 1] + {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<3xf32> + return +} -- 2.7.4