From 10d7924581f8f29c558d089c2546321de26f8849 Mon Sep 17 00:00:00 2001
From: Ahmed Taei
Date: Mon, 19 Apr 2021 16:52:47 -0700
Subject: [PATCH] Fix FoldReshapeOpWithUnitExtent generating illegal reshape

This will prevent fusion that spans all dims and generates a
(d0, d1, ...) -> () reshape, which isn't legal.

Differential Revision: https://reviews.llvm.org/D100805
---
 mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp |  9 +++++++++
 mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir | 18 ++++++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 1554059..5d8a664 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -518,7 +518,16 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern {
       } else {
        return failure();
      }
+
      foldedDim++;
+      // If the innermost dims are folded, there shouldn't be any remaining
+      // unit dims; otherwise these dims are not mapped and will lead to an
+      // illegal reshape.
+      if (expandedDim == expandedShape.size()) {
+        if (foldedDim < foldedShape.size() && foldedShape[foldedDim] == 1) {
+          return failure();
+        }
+      }
    }
    if (expandedDim != expandedShape.size())
      return failure();
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 8a36f6d..e9dd74f 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -647,3 +647,21 @@ func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x1xf32>
 // CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] [#[[MAP2]]]
 // CHECK: return %[[RESULT_RESHAPE]]
+
+// -----
+
+func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>] : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d3)>] : tensor<3x2x2x1xf32> into tensor<12x1xf32>
+  return %1 : tensor<12x1xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: func @no_fold_reshape_empty_expr
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32>
+// CHECK: %[[RARG0:.+]] = linalg.tensor_reshape %[[ARG0:.+]] [#[[MAP0]], #[[MAP1]], #[[MAP2]]
+// CHECK: %[[RES:.+]] = linalg.tensor_reshape %[[RARG0:.+]] [#[[MAP3]], #[[MAP4]]]
+// CHECK: return %[[RES:.+]] : tensor<12x1xf32>
-- 
2.7.4
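
For reference, the illegal op this guard prevents: folding the two reshapes
in the new test would map all three source dims onto the first result dim
(3 * 2 * 2 = 12) and leave the trailing unit dim of tensor<12x1xf32> with no
source dim to map, i.e. a reassociation map with an empty result expression.
A minimal sketch of that illegal reshape (the %bad name is hypothetical and
for illustration only; the verifier rejects this op):

  %bad = linalg.tensor_reshape %arg0
      [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
       affine_map<(d0, d1, d2) -> ()>]  // empty result expression: nothing
                                        // maps to the trailing unit dim
      : tensor<3x2x2xf32> into tensor<12x1xf32>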