return success();
}
};
+
+template <typename OpTy>
+struct CanonicalizeCastExtentTensorOperandsPattern
+ : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ // Canonicalize operands.
+ bool anyChange = false;
+ auto canonicalizeOperand = [&](Value operand) {
+ if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
+ // Only eliminate the cast if it holds no shape information.
+ bool isInformationLoosingCast =
+ castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
+ if (isInformationLoosingCast) {
+ anyChange = true;
+ return castOp.source();
+ }
+ }
+ return operand;
+ };
+ auto newOperands = llvm::to_vector<8>(
+ llvm::map_range(op.getOperands(), canonicalizeOperand));
+
+ // Rewrite op if any change required.
+ if (!anyChange)
+ return failure();
+ rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
+ return success();
+ }
+};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<BroadcastFoldConstantOperandsPattern,
BroadcastForwardSingleOperandPattern,
+ CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
RemoveDuplicateOperandsPattern<BroadcastOp>,
RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}
// Canonicalization patterns have overlap with the considerations during
// folding in case additional shape information is inferred at some point that
// does not result in folding.
- patterns.add<CstrBroadcastableEqOps,
+ patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
+ CstrBroadcastableEqOps,
RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
-struct ShapeOfCastedExtentTensor : public OpRewritePattern<tensor::CastOp> {
+struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::CastOp op,
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<ShapeOfCastedExtentTensor, ShapeOfWithTensor>(context);
+ patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
}
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @fold_index_cast_on_index
func @fold_index_cast_on_index(%arg: index) -> index {
// CHECK-NOT: size_to_index
- %casted = shape.size_to_index %arg : index
- return %casted : index
+ %0 = shape.size_to_index %arg : index
+ return %0 : index
}
// -----
// CHECK-LABEL: @fold_to_extent_tensor_on_tensor
func @fold_to_extent_tensor_on_tensor(%arg: tensor<?xindex>) -> tensor<?xindex> {
// CHECK-NOT: to_extent_tensor
- %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
- return %casted : tensor<?xindex>
+ %0 = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
+ return %0 : tensor<?xindex>
}
// -----
// -----
-// CHECK-LABEL: @casted_extent_tensor
+// CHECK-LABEL: @cast_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
-func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
+func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
// CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<?xindex>
// CHECK: return %[[RESULT]] : tensor<?xindex>
%0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// -----
-// CHECK-LABEL: @casted_extent_tensor
+// CHECK-LABEL: @cast_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<3xindex>
-func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
+func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
// CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<3xindex>
// CHECK: return %[[RESULT]] : tensor<3xindex>
%0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// -----
-// CHECK-LABEL: @casted_extent_tensor
-func @casted_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
+// CHECK-LABEL: @cast_extent_tensor
+func @cast_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
// CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
%0 = shape.shape_of %arg : tensor<?x?x?x?xf32> -> tensor<?xindex>
%1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
// -----
-// CHECK-LABEL: @casted_extent_tensor
-func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
+// CHECK-LABEL: @cast_extent_tensor
+func @cast_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
// CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
%0 = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
%1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
%2 = shape.cstr_broadcastable %0, %1: tensor<2xindex>, tensor<1xindex>
"use"(%2) : (!shape.witness) -> ()
}
+
+// -----
+
+// CHECK-LABEL: @cast_extent_tensor_operands
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<3xindex>)
+func @cast_extent_tensor_operands(%arg0 : tensor<?xindex>,
+ %arg1 : tensor<3xindex>) -> (!shape.witness, tensor<?xindex>) {
+ // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?xindex> to tensor<3xindex>
+ // CHECK: %[[WIT:.*]] = shape.cstr_broadcastable %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
+ // CHECK: %[[RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
+ // CHECK: return %[[WIT]], %[[RES]]
+ %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<3xindex>
+ %1 = tensor.cast %arg1 : tensor<3xindex> to tensor<?xindex>
+ %2 = shape.cstr_broadcastable %0, %1 : tensor<3xindex>, tensor<?xindex>
+ %3 = shape.broadcast %0, %1 :tensor<3xindex>, tensor<?xindex>
+ -> tensor<?xindex>
+ return %2, %3 : !shape.witness, tensor<?xindex>
+}