Do not generate CollapseShapeOps/ExpandShapeOps that have the same source and result shape. Generate casts instead. Such reshapes became invalid with D138498.
Differential Revision: https://reviews.llvm.org/D138557
//
/// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
/// and `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy>
+template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
SmallVector<ReassociationIndices, 4> higherRankReassociation,
lowerRankReassociation;
- bool isResultCollapsed = srcRank > resultRank;
- if (isResultCollapsed) {
+ if (srcRank > resultRank) {
higherRankReassociation = expandOp.getReassociationIndices();
lowerRankReassociation = collapseOp.getReassociationIndices();
} else {
}
composedReassociation.push_back(composedIndices);
}
- if (isResultCollapsed)
+ if (srcRank > resultRank) {
rewriter.replaceOpWithNewOp<CollapseOpTy>(
collapseOp, resultType, expandOp.getSrc(), composedReassociation);
- else
+ } else if (srcRank < resultRank) {
rewriter.replaceOpWithNewOp<ExpandOpTy>(
collapseOp, resultType, expandOp.getSrc(), composedReassociation);
+ } else {
+ // Collapses/expansions that do not change the rank are not allowed. Use
+ // a cast instead.
+ assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
+ "expected same shape");
+ rewriter.replaceOpWithNewOp<CastOpTy>(collapseOp, resultType,
+ expandOp.getSrc());
+ }
return success();
}
};
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
- ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+ ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
CollapseShapeOpMemRefCastFolder>(context);
}
MLIRContext *context) {
results
.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
- ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+ ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
FoldReshapeWithConstant<CollapseShapeOp>,
FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
context);
memref.store %v, %0[%i2] : memref<4xf32>
return %src : memref<2xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @collapse_expand_fold_to_cast(
+// CHECK-SAME: %[[m:.*]]: memref<?xf32, strided<[1]>, 3>
+// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<?xf32, strided<[1]>, 3> to memref<?xf32, 3
+// CHECK: return %[[casted]]
+func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>)
+ -> (memref<?xf32, 3>)
+{
+ %0 = memref.expand_shape %m [[0, 1]]
+ : memref<?xf32, strided<[1]>, 3> into memref<1x?xf32, 3>
+ %1 = memref.collapse_shape %0 [[0, 1]]
+ : memref<1x?xf32, 3> into memref<?xf32, 3>
+ return %1 : memref<?xf32, 3>
+}
%1 = tensor.dim %0, %c1 : tensor<?x?xf32>
return %1 : index
}
+
+// -----
+
+// CHECK-LABEL: func @collapse_expand_fold_to_cast(
+// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
+// CHECK: return %[[t]]
+func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>) -> (tensor<?xf32>)
+{
+ %0 = tensor.expand_shape %t [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+ %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+ return %1 : tensor<?xf32>
+}