namespace mlir {
namespace tensor {
+/// Returns true if `target` is a ranked tensor type that preserves static
+/// information available in the `source` ranked tensor type. Both types must
+/// be ranked tensors with the same rank and element type, and `target` may
+/// not turn a dimension that is static in `source` into a dynamic one.
+bool preservesStaticInformation(Type source, Type target);
+
/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
return success();
}
};
+
+// Fold CastOp using the result of PadTensorOp back into the latter if it adds
+// static information.
+//
+// Example:
+//   %0 = linalg.pad_tensor %arg0 ... : tensor<?x?xf32> to tensor<?x?xf32>
+//   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<32x32xf32>
+// becomes a pad_tensor that directly yields tensor<32x32xf32>.
+struct FoldTargetTensorCast : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    // Bail out unless the cast is the sole user of the pad result; any other
+    // user would still require the original (more dynamic) result type.
+    if (!padTensorOp.result().hasOneUse())
+      return failure();
+    auto tensorCastOp =
+        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
+    if (!tensorCastOp)
+      return failure();
+    // The cast target must keep every static dimension of the pad result,
+    // i.e. the cast only adds static information; otherwise folding would
+    // lose static sizes.
+    if (!tensor::preservesStaticInformation(padTensorOp.result().getType(),
+                                            tensorCastOp.dest().getType()))
+      return failure();
+
+    // Recreate the pad with the refined result type, moving (not cloning)
+    // the padding-value region into the replacement op.
+    auto replacementOp = rewriter.create<PadTensorOp>(
+        padTensorOp.getLoc(), tensorCastOp.dest().getType(),
+        padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
+        padTensorOp.static_low(), padTensorOp.static_high());
+    replacementOp.region().takeBody(padTensorOp.region());
+
+    // Both the pad and the cast are replaced by the new, more static pad.
+    rewriter.replaceOp(padTensorOp, replacementOp.result());
+    rewriter.replaceOp(tensorCastOp, replacementOp.result());
+    return success();
+  }
+};
} // namespace
void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldStaticZeroPadding, FoldSourceTensorCast>(context);
+// Also fold a trailing tensor.cast that only refines the result type with
+// static size information (FoldTargetTensorCast).
+ results.add<FoldTargetTensorCast>(context);
}
/// Return the padding value of the PadTensorOp if it constant. In this context,
// CastOp
//===----------------------------------------------------------------------===//
+/// Returns true if `target` is a ranked tensor type that preserves static
+/// information available in the `source` ranked tensor type.
+bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
+  auto sourceType = source.dyn_cast<RankedTensorType>();
+  auto targetType = target.dyn_cast<RankedTensorType>();
+
+  // Requires RankedTensorType.
+  if (!sourceType || !targetType)
+    return false;
+
+  // Requires same elemental type.
+  if (sourceType.getElementType() != targetType.getElementType())
+    return false;
+
+  // Requires same rank.
+  if (sourceType.getRank() != targetType.getRank())
+    return false;
+
+  // A dimension that is static in `source` must stay static in `target`;
+  // a static -> dynamic transition would discard static information.
+  // (Dynamic -> static is fine: `target` then carries strictly more info.)
+  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
+    if (!ShapedType::isDynamic(std::get<0>(t)) &&
+        ShapedType::isDynamic(std::get<1>(t)))
+      return false;
+  }
+
+  return true;
+}
+
/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
if (!castOp)
return false;
- RankedTensorType sourceType =
- castOp.source().getType().dyn_cast<RankedTensorType>();
- RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
-
- // Requires RankedTensorType.
- if (!sourceType || !resultType)
- return false;
-
- // Requires same elemental type.
- if (sourceType.getElementType() != resultType.getElementType())
- return false;
-
- // Requires same rank.
- if (sourceType.getRank() != resultType.getRank())
- return false;
-
- // If cast is towards more static sizes along any dimension, don't fold.
- for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
- if (ShapedType::isDynamic(std::get<0>(t)) &&
- !ShapedType::isDynamic(std::get<1>(t)))
- return false;
- }
-
- return true;
+ // Can fold if the source of cast has at least as much static information as
+ // its results.
+ return preservesStaticInformation(castOp.getType(),
+ castOp.source().getType());
}
/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
// -----
+// The cast refines tensor<?x?xf32> to the fully static tensor<32x32xf32>, so
+// FoldTargetTensorCast folds it away: the pad itself produces the static type.
+// CHECK-LABEL: @cast_of_pad_more_static
+func @cast_of_pad_more_static(%arg0: tensor<?x?xf32>, %padding: index) -> tensor<32x32xf32> {
+  %cst = constant 0.000000e+00 : f32
+  // CHECK: %[[PAD:.*]] = linalg.pad_tensor
+  // CHECK: tensor<?x?xf32> to tensor<32x32xf32>
+  %padded = linalg.pad_tensor %arg0 low[%padding, %padding] high[0, 0] {
+  ^bb0(%arg1: index, %arg2: index):
+    linalg.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<?x?xf32>
+  // CHECK-NOT: tensor.cast
+  %casted = tensor.cast %padded : tensor<?x?xf32> to tensor<32x32xf32>
+  // CHECK: return %[[PAD]]
+  return %casted : tensor<32x32xf32>
+}
+
+// -----
+
+// Negative test: the cast turns the static leading dimension (32) dynamic, so
+// it does not preserve static information and must NOT be folded into the pad.
+// CHECK-LABEL: @cast_of_pad_less_static
+func @cast_of_pad_less_static(%arg0: tensor<32x?x?xf32>, %padding: index) -> tensor<?x32x32xf32> {
+  %cst = constant 0.000000e+00 : f32
+  // CHECK: linalg.pad_tensor
+  %padded = linalg.pad_tensor %arg0 low[%padding, %padding, %padding] high[0, 0, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index):
+    linalg.yield %cst : f32
+  } : tensor<32x?x?xf32> to tensor<32x?x?xf32>
+  // CHECK: %[[CAST:.*]] = tensor.cast
+  %casted = tensor.cast %padded : tensor<32x?x?xf32> to tensor<?x32x32xf32>
+  // CHECK: return %[[CAST]]
+  return %casted : tensor<?x32x32xf32>
+}
+
+// -----
+
func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
%arg3 : index) -> tensor<?x?xf32> {
%c0 = constant 0 : index
// CHECK: } else {
// CHECK: %[[SUBTENSOR:.*]] = tensor.extract_slice %[[ARG0]][%{{.*}}, 4] [%{{.*}}, 1] [1, 1] : tensor<?x5xf32> to tensor<?x1xf32>
// CHECK: %[[PADTENSOR:.*]] = linalg.pad_tensor %[[SUBTENSOR]] low[0, 0] high[%{{.*}}, 3]
-// CHECK: %[[CAST:.*]] = tensor.cast %[[PADTENSOR]] : tensor<?x4xf32> to tensor<3x4xf32>
-// CHECK: scf.yield %[[CAST]]
+// CHECK: scf.yield %[[PADTENSOR]]
// CHECK: }
// CHECK: return %[[RESULT]]
func @dynamic_high_pad(%arg0 : tensor<?x5xf32>, %h1: index, %pad : f32) -> tensor<3x4xf32> {
// CHECK: else
// CHECK: tensor.extract_slice
// CHECK: linalg.pad_tensor
-// CHECK: tensor.cast
// CHECK: tensor.extract_slice
// CHECK: tensor.extract_slice
// CHECK: linalg.generic
// TILE1: else
// TILE1: %[[SLICE:.*]] = tensor.extract_slice %arg0[0, %{{.*}}] [7, %{{.*}}] [1, 1] : tensor<7x9xf32> to tensor<7x?xf32>
// TILE1: %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]] low[0, 0] high[7, %{{.*}}]
-// TILE1: %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<14x?xf32> to tensor<14x3xf32>
-// TILE1: scf.yield %[[CAST]] : tensor<14x3xf32>
+// TILE1: scf.yield %[[PAD]] : tensor<14x3xf32>
// TILE1: %[[R3:.*]] = tensor.insert_slice %[[R2]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32>
// TILE1: scf.yield %[[R3]] : tensor<14x15xf32>
// TILE1: return %[[RESULT]] : tensor<14x15xf32>