From: Matthias Gehre
Date: Thu, 13 Jul 2023 06:53:47 +0000 (+0200)
Subject: [MLIR] [TOSA]: Move reshape(reshape(x)) -> reshape(x) from canonicalization to fold
X-Git-Tag: upstream/17.0.6~1553
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=0ebb0503113e97eb14dc679f06bdc1d2e7296d54;p=platform%2Fupstream%2Fllvm.git

[MLIR] [TOSA]: Move reshape(reshape(x)) -> reshape(x) from canonicalization to fold

reshape(reshape(x)) -> reshape(x) can be written directly as a fold
instead of a canonicalization, which helps other passes clean up while
they work.

This initially broke ReshapeConverterExpand/Collapse, which relied on
creating foldable reshapes and a carefully crafted benefit priority
among the patterns. I turned them into a single pattern on reshape that
expands and/or collapses as needed in one go.

Differential Revision: https://reviews.llvm.org/D155266
---

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e5b4e66..7a16b37 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1480,7 +1480,6 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [
     No data conversion happens during a reshape operation.
   }];
 
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
 
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 1a9f48e3..f51ada8 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -129,81 +129,74 @@ static bool createReassociationMapsForCollapse(
 }
 
 namespace {
-class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
-public:
-  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
-    ShapedType resultTy = cast<ShapedType>(reshape.getType());
-    bool isDynamic = !operandTy.hasStaticShape();
-
-    if (isDynamic && resultTy.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          reshape, "Cannot collapse dynamic dims to more than one dimension");
-    }
-
-    SmallVector<ReassociationExprs, 4> reassociationMap;
-    if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
-                                            resultTy.getShape(),
-                                            reassociationMap, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape,
-          "tosa.reshape Attempting to collapse into an incompatible shape");
-    }
+Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
+                     ShapedType resultTy, Value operand) {
+  ShapedType operandTy = cast<ShapedType>(operand.getType());
+  if (resultTy == operandTy)
+    return operand;
+
+  bool isDynamic = !operandTy.hasStaticShape();
+
+  if (isDynamic && resultTy.getRank() != 1) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "Cannot collapse dynamic dims to more than one dimension");
+    return {};
+  }
 
-    SmallVector<int64_t> intermediateShape;
-    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                               intermediateShape, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot collapse into given shape");
-    }
+  SmallVector<ReassociationExprs, 4> reassociationMap;
+  if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
+                                          resultTy.getShape(),
+                                          reassociationMap, isDynamic)) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "tosa.reshape Attempting to collapse into an incompatible shape");
+    return {};
+  }
 
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
-    return success();
+  SmallVector<int64_t> intermediateShape;
+  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                             intermediateShape, isDynamic)) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "tosa.reshape Cannot collapse into given shape");
+    return {};
   }
-};
+  return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
+                                                  reassociationMap);
+}
 
-class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
-public:
-  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
+                   ShapedType resultTy, Value operand) {
+  ShapedType operandTy = cast<ShapedType>(operand.getType());
+  if (resultTy == operandTy)
+    return operand;
 
-  LogicalResult
-  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
-    ShapedType resultTy = cast<ShapedType>(reshape.getType());
-    bool isDynamic = !operandTy.hasStaticShape();
+  bool isDynamic = !operandTy.hasStaticShape();
 
-    if (isDynamic && operandTy.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          reshape, "Cannot expand dynamic dims from more than one dimension");
-    }
+  if (isDynamic && operandTy.getRank() != 1) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "Cannot expand dynamic dims from more than one dimension");
+    return {};
+  }
 
-    SmallVector<ReassociationExprs, 4> reassociationMap;
-    if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
-                                            operandTy.getShape(),
-                                            reassociationMap, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape,
-          "tosa.reshape Attempting to expand into an incompatible shape");
-    }
+  SmallVector<ReassociationExprs, 4> reassociationMap;
+  if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
+                                          operandTy.getShape(),
+                                          reassociationMap, isDynamic)) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "tosa.reshape Attempting to expand into an incompatible shape");
+    return {};
+  }
 
-    SmallVector<int64_t> intermediateShape;
-    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                               intermediateShape, isDynamic) ||
-        intermediateShape != operandTy.getShape()) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot expand into given shape");
-    }
-    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
-    return success();
+  SmallVector<int64_t> intermediateShape;
+  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                             intermediateShape, isDynamic) ||
+      intermediateShape != operandTy.getShape()) {
+    (void)rewriter.notifyMatchFailure(
+        loc, "tosa.reshape Cannot expand into given shape");
+    return {};
   }
-};
+  return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
+                                                reassociationMap);
+}
 
 class ReshapeConverterCollapseExpand
     : public OpConversionPattern<tosa::ReshapeOp> {
@@ -224,17 +217,19 @@ public:
           reshape, "tosa.reshape Cannot identify an intermediate shape between "
                    "the given two shapes");
     }
+    auto intermediateTy = RankedTensorType::get(
+        intermediateShape, reshape.getType().getElementType());
 
-    Value collapse = rewriter.create<tosa::ReshapeOp>(
-        reshape.getLoc(),
-        RankedTensorType::get(intermediateShape,
-                              reshape.getType().getElementType()),
-        adaptor.getInput1(), rewriter.getDenseI64ArrayAttr(intermediateShape));
-    Value expand = rewriter.create<tosa::ReshapeOp>(
-        reshape.getLoc(), resultTy, collapse,
-        rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
-    rewriter.replaceOp(reshape, expand);
+    Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy,
+                                    adaptor.getInput1());
+    if (!collapse)
+      return failure();
+    Value expand = createExpand(rewriter, reshape.getLoc(), resultTy, collapse);
+    if (!expand)
+      return failure();
+
+    rewriter.replaceOp(reshape, expand);
     return success();
   }
 };
@@ -420,10 +415,6 @@ void mlir::tosa::populateTosaToTensorConversionPatterns(
     RewritePatternSet *patterns) {
   patterns->add<SliceConverter, PadConverter, ConcatConverter>(
       patterns->getContext());
-  patterns->add<ReshapeConverterCollapse>(patterns->getContext(),
-                                          /*benefit=*/100);
-  patterns->add<ReshapeConverterExpand>(patterns->getContext(),
-                                        /*benefit=*/200);
-  patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext(),
-                                                /*benefit=*/300);
+
+  patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
 }
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 8cefa64..152b885 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -62,31 +62,6 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ConcatOptimization>(context);
 }
 
-struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
-  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getInput1();
-    Operation *definingOp = input.getDefiningOp();
-    if (!definingOp)
-      return failure();
-
-    if (tosa::ReshapeOp reshapeOp = dyn_cast<tosa::ReshapeOp>(definingOp)) {
-      rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-          op, op.getType(), reshapeOp.getInput1(), op.getNewShape());
-      return success();
-    }
-
-    return failure();
-  }
-};
-
-void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                            MLIRContext *context) {
-  results.add<ReshapeReshapeOptimization>(context);
-}
-
 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
   auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
   if (!notOp)
@@ -820,25 +795,32 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   if (inputTy == outputTy)
     return getInput1();
 
-  // Constants must have static shape.
-  if (!outputTy.hasStaticShape())
-    return {};
+  // reshape(reshape(x)) -> reshape(x)
+  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
+          getInput1().getDefiningOp())) {
+    getInput1Mutable().assign(reshapeOp.getInput1());
+    return getResult();
+  }
 
-  auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  if (!operand)
-    return {};
+  // reshape(const(x)) -> const(reshape-attr(x))
+  if (auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
+    // Constants must have static shape.
+    if (!outputTy.hasStaticShape())
+      return {};
 
-  // Okay to duplicate splat constants.
-  if (operand.isSplat()) {
-    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
-  }
+    // Okay to duplicate splat constants.
+    if (operand.isSplat())
+      return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
 
-  // Don't duplicate other constants.
-  if (!getInput1().hasOneUse())
-    return {};
+    // Don't duplicate other constants.
+    if (!getInput1().hasOneUse())
+      return {};
 
-  return operand.reshape(
-      llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
+    return operand.reshape(
+        llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
+  }
+
+  return {};
 }
 
 OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
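
Illustration (not part of the commit): the new fold rewires a reshape whose
producer is another reshape so that it reads the inner reshape's input
directly; the inner reshape then becomes dead and is cleaned up separately.
The sketch below is a minimal example written in the generic TOSA assembly
form, assuming the new_shape DenseI64ArrayAttr used in this patch; the exact
textual syntax may differ between MLIR versions.

    // Before the fold: two back-to-back reshapes.
    %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 6>}
           : (tensor<3x4xf32>) -> tensor<2x6xf32>
    %1 = "tosa.reshape"(%0) {new_shape = array<i64: 12>}
           : (tensor<2x6xf32>) -> tensor<12xf32>

    // After ReshapeOp::fold: %1 consumes %arg0 directly and %0 is now unused.
    %1 = "tosa.reshape"(%arg0) {new_shape = array<i64: 12>}
           : (tensor<3x4xf32>) -> tensor<12xf32>

Because the rewrite now lives in the folder rather than in a canonicalization
pattern, any driver that invokes folding (the greedy pattern rewriter,
OpBuilder::createOrFold, or -canonicalize) performs this simplification
without needing the TOSA canonicalization patterns to be registered, which is
the "help other passes clean up while they work" motivation above.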