#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <numeric>
+
using namespace mlir;
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
}
};
+class ReshapeOpConverter : public OpConversionPattern<tosa::ReshapeOp> {
+public:
+ using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tosa::ReshapeOp reshape, ArrayRef<Value> args,
+ ConversionPatternRewriter &rewriter) const final {
+ typename tosa::ReshapeOp::Adaptor operands(args);
+
+ ShapedType operandTy = operands.input1().getType().cast<ShapedType>();
+ ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+
+ if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
+ return failure();
+
+ // Compute the reassociation maps for the linalg operation.
+ ArrayRef<int64_t> expandedShape =
+ (operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
+ : resultTy.getShape());
+ ArrayRef<int64_t> collapsedShape =
+ (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
+ : operandTy.getShape());
+ unsigned currSrcDim = 0, currDstDim = 0;
+ SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
+ collapsedShape.size());
+
+ // First scan all dimensions in the source shapes to see whether we have a
+ // perfect case where consecutive dimensions in source are collapsed. For
+ // such case we can just generate one single linalg.reshape.
+ bool isCollapsingSource = true;
+ while (currSrcDim < expandedShape.size() &&
+ currDstDim < collapsedShape.size()) {
+ int64_t dstSize = collapsedShape[currDstDim];
+ int64_t srcSize = expandedShape[currSrcDim];
+ while (srcSize < dstSize && currSrcDim < expandedShape.size()) {
+ reassociationMap[currDstDim].push_back(
+ rewriter.getAffineDimExpr(currSrcDim++));
+ srcSize *= expandedShape[currSrcDim];
+ }
+ if (srcSize == dstSize) {
+ reassociationMap[currDstDim].push_back(
+ rewriter.getAffineDimExpr(currSrcDim++));
+ // If the next dim in collapsedShape is not 1, treat subsequent dims in
+ // expandedShape which are 1 to be collapsed.
+ if (currDstDim == collapsedShape.size() - 1 ||
+ collapsedShape[currDstDim + 1] != 1) {
+ while (currSrcDim < expandedShape.size() &&
+ expandedShape[currSrcDim] == 1) {
+ reassociationMap[currDstDim].push_back(
+ rewriter.getAffineDimExpr(currSrcDim++));
+ }
+ }
+ } else {
+ isCollapsingSource = false;
+ break;
+ }
+ currDstDim++;
+ }
+ if (currSrcDim != expandedShape.size() ||
+ currDstDim != collapsedShape.size())
+ isCollapsingSource = false;
+
+ // Otherwise, we need to first reduce all source dimensions into one and
+ // then expand to the destination dimensions.
+ if (!isCollapsingSource) {
+ auto getIdentityExprs = [&rewriter](int n) {
+ SmallVector<AffineExpr, 4> exprs;
+ for (int i = 0; i < n; ++i)
+ exprs.push_back(rewriter.getAffineDimExpr(i));
+ return exprs;
+ };
+ Location loc = reshape.getLoc();
+ int64_t totalElems =
+ std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
+ std::multiplies<int64_t>());
+ auto elemTy = operandTy.getElementType();
+ SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
+ // Use operandTy here because we need to collapse all operands
+ // dimensions.
+ getIdentityExprs(operandTy.getShape().size())};
+ SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
+ // Use resultTy here because we need to expand to all result
+ // dimensions.
+ getIdentityExprs(resultTy.getShape().size())};
+
+ auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
+ Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
+ loc, collapsedTy, args[0], collapsingMap);
+ rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
+ reshape, resultTy, collapsedOp, expandingMap);
+
+ return success();
+ }
+
+ rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
+ reshape, resultTy, args[0], reassociationMap);
+ return success();
+ }
+};
+
} // namespace
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
PointwiseConverter<tosa::GreaterEqualOp>,
PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
- PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>>(
- context);
+ PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
+ ReshapeOpConverter>(context);
}
return
}
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @test_reshape_downrank
+func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
+ // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
+ %0 = "tosa.reshape"(%arg0) {new_shape = [6]} : (tensor<2x3xf32>) -> tensor<6xf32>
+ // CHECK: return [[RESHAPE]]
+ return %0 : tensor<6xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @test_reshape_uprank
+func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
+ // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
+ %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<6xf32>) -> tensor<2x3xf32>
+ // CHECK: return [[RESHAPE]]
+ return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @test_reshape_samerank
+func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
+ // CHECK: [[RESHAPE1:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
+ // CHECK: [[RESHAPE2:%.+]] = linalg.tensor_reshape [[RESHAPE1]] [#[[$MAP0]]]
+ %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<3x2xf32>) -> tensor<2x3xf32>
+ // CHECK: return [[RESHAPE2]]
+ return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+
+// CHECK-LABEL: @test_reshape_downrank_6D
+func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
+ // CHECK: linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+ %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+ return %0 : tensor<6x5x77xf32>
+}