From 1b7feac2a6c42f5f4302579eeafbe904f5ccf972 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Fri, 22 Jul 2022 07:20:24 -0700 Subject: [PATCH] [mlir][tosa] Split canonicalization and folders out of TosaOps. Scope ops file to ops. Used canonicalization as grouping for canonicalization patterns and folders (also considered OpTransforms but that felt too generic and the former two are used together). Reviewed By: silvas, rsuderman Differential Revision: https://reviews.llvm.org/D130297 --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 11 - mlir/lib/Dialect/Tosa/CMakeLists.txt | 3 +- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 543 +++++++++++++++++++++ mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 529 -------------------- .../Transforms/TosaLayerwiseConstantFoldPass.cpp | 16 +- 5 files changed, 560 insertions(+), 542 deletions(-) create mode 100644 mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h index ff4225b..afdd801 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -34,17 +34,6 @@ namespace tosa { } // namespace tosa } // namespace mlir -//===----------------------------------------------------------------------===// -// Utility Functions -//===----------------------------------------------------------------------===// -namespace mlir { -namespace tosa { -/// Appends the canonicalization patterns for all the TOSA ops to the `patterns` -void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx, - RewritePatternSet &patterns); -} // namespace tosa -} // namespace mlir - #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc" diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt index 520f642..77e9051 100644 --- a/mlir/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -1,7 +1,8 @@ add_mlir_dialect_library(MLIRTosaDialect + IR/TosaOps.cpp + IR/TosaCanonicalizations.cpp Utils/ConversionUtils.cpp Utils/QuantUtils.cpp - IR/TosaOps.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp new file mode 100644 index 0000000..7bb6339 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -0,0 +1,543 @@ +//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// \file +// TOSA canonicalization patterns and folders. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/InliningUtils.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::tosa; + +//===----------------------------------------------------------------------===// +// Operator Canonicalizers. +//===----------------------------------------------------------------------===// + +struct ConcatOptimization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ConcatOp op, + PatternRewriter &rewriter) const override { + if (op.input1().size() != 1) + return failure(); + if (op.input1().front().getType() != op.getType()) { + rewriter + .replaceOpWithNewOp(op, op.getType(), + op.input1().front()) + .getResult(); + return success(); + } + + rewriter.replaceOp(op, op.input1().front()); + return success(); + } +}; + +void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +struct ReshapeReshapeOptimization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ReshapeOp op, + PatternRewriter &rewriter) const override { + Value input = op.input1(); + Operation *definingOp = input.getDefiningOp(); + if (!definingOp) + return failure(); + + if (tosa::ReshapeOp reshapeOp = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp( + op, op.getType(), reshapeOp.input1(), op.new_shape()); + return success(); + } + + return failure(); + } +}; + +struct ReshapeConstOptimization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ReshapeOp op, + PatternRewriter &rewriter) const override { + Value input = op.input1(); + ArrayAttr newShape = op.new_shape(); + + // Check if input is constant + DenseElementsAttr inputAttr; + if (!matchPattern(input, m_Constant(&inputAttr))) + return failure(); + + // Check if has >1 consumer and is not splat + if (!input.hasOneUse() && !inputAttr.isSplat()) + return failure(); + + // Grab the new shape + SmallVector newShapeValues = llvm::to_vector<6>( + llvm::map_range(newShape.getValue(), [](const Attribute &val) { + return val.cast().getValue().getSExtValue(); + })); + + // Build new const op with correct output shape + ShapedType inputShape = input.getType().cast(); + DenseElementsAttr outputAttr = + inputAttr.reshape(inputShape.clone(newShapeValues)); + rewriter.replaceOpWithNewOp(op, outputAttr.getType(), + outputAttr); + return success(); + } +}; + +void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); + results.add(context); +} + +LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) { + auto notOp = op.pred().getDefiningOp(); + if (!notOp) + return failure(); + rewriter.updateRootInPlace(op, [&]() { + op.getOperation()->setOperands( + {notOp.input1(), op.on_false(), op.on_true()}); + }); + return success(); +} + +struct NoOpOptimization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::TransposeOp op, + PatternRewriter &rewriter) const override { + auto perm = op.perms(); + + DenseIntElementsAttr permAttr; + if (!matchPattern(perm, m_Constant(&permAttr))) { + return failure(); + } + + SmallVector permValues = llvm::to_vector<6>( + llvm::map_range(permAttr.getValues(), + [](const APInt &val) { return val.getSExtValue(); })); + + for (int i = 0, s = permValues.size(); i < s; i++) { + if (i != permValues[i]) { + return failure(); + } + } + + rewriter.replaceOp(op, op.input1()); + return success(); + } +}; + +void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +struct AddZeroOptimization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::AddOp op, + PatternRewriter &rewriter) const override { + auto input1 = op.input1(); + auto input2 = op.input2(); + + DenseElementsAttr input1Attr; + if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() && + input2.getType() == op.getType()) { + if (input1Attr.getType().getElementType().isa() && + input1Attr.getSplatValue().isZero()) { + rewriter.replaceOp(op, op.input2()); + return success(); + } + } + + DenseElementsAttr input2Attr; + if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() && + input1.getType() == op.getType()) { + if (input2Attr.getType().getElementType().isa() && + input2Attr.getSplatValue().isZero()) { + rewriter.replaceOp(op, op.input1()); + return success(); + } + } + + return failure(); + } +}; + +void AddOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +struct MulOneOptimization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::MulOp op, + PatternRewriter &rewriter) const override { + auto input1 = op.input1(); + auto input2 = op.input2(); + + DenseElementsAttr input1Attr; + if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() && + input2.getType() == op.getType()) { + if (input1Attr.getType().getElementType().isa() && + input1Attr.getSplatValue().isExactlyValue(1)) { + rewriter.replaceOp(op, op.input2()); + return success(); + } + + if (input1Attr.getType().getElementType().isa() && + matchPattern(input1, m_One())) { + rewriter.replaceOp(op, op.input2()); + return success(); + } + } + + DenseElementsAttr input2Attr; + if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() && + input1.getType() == op.getType()) { + if (input2Attr.getType().getElementType().isa() && + input2Attr.getSplatValue().isExactlyValue(1)) { + rewriter.replaceOp(op, op.input1()); + return success(); + } + + if (input2Attr.getType().getElementType().isa() && + matchPattern(input2, m_One())) { + rewriter.replaceOp(op, op.input1()); + return success(); + } + } + + return failure(); + } +}; + +void MulOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +struct MaterializePadValue : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::PadOp op, + PatternRewriter &rewriter) const override { + if (op.pad_const()) + return failure(); + + auto input = op.input1(); + auto padding = op.padding(); + + ShapedType inputTy = input.getType().cast(); + Type elementTy = inputTy.getElementType(); + + Attribute constantAttr; + if (elementTy.isa()) { + constantAttr = rewriter.getFloatAttr(elementTy, 0.0); + } else if (elementTy.isa() && !op.quantization_info()) { + constantAttr = rewriter.getIntegerAttr(elementTy, 0); + } else if (elementTy.isa() && op.quantization_info()) { + auto value = op.quantization_info()->getInputZp(); + constantAttr = rewriter.getIntegerAttr(elementTy, value); + } + + if (!constantAttr) { + return rewriter.notifyMatchFailure( + op, + "tosa.pad to linalg lowering encountered an unknown element type"); + } + + auto denseAttr = DenseElementsAttr::get( + RankedTensorType::get({}, elementTy), constantAttr); + auto constantVal = rewriter.create( + op.getLoc(), denseAttr.getType(), denseAttr); + + rewriter.replaceOpWithNewOp( + op, op.getType(), ValueRange{input, padding, constantVal}, + op->getAttrs()); + return success(); + } +}; + +void PadOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +struct MaxPool2dIsNoOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, + PatternRewriter &rewriter) const override { + Value input = op.input(); + Value output = op.output(); + ShapedType inputType = input.getType().cast(); + ShapedType outputType = output.getType().cast(); + + if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) { + return failure(); + } + + // If the output and input shapes are 1x1, then this is a no op. + ArrayRef outputShape = outputType.getShape(); + if (outputShape[1] != 1 || outputShape[2] != 1) { + return failure(); + } + + ArrayRef inputShape = inputType.getShape(); + if (inputShape[1] != 1 || inputShape[2] != 1) { + return failure(); + } + + rewriter.replaceOp(op, input); + return success(); + } +}; + +void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +struct ClampIsNoOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ClampOp op, + PatternRewriter &rewriter) const override { + Value input = op.input(); + auto inputType = op.input().getType().template dyn_cast(); + auto inputElementType = inputType.getElementType(); + + if (!inputType.hasStaticShape()) { + return failure(); + } + + if (inputElementType.isF32()) { + auto minClamp = op.min_fp(); + auto maxClamp = op.max_fp(); + bool isMin = (minClamp.isLargest() || minClamp.isInfinity()) && + minClamp.isNegative(); + bool isMax = (maxClamp.isLargest() || maxClamp.isInfinity()) && + !maxClamp.isNegative(); + + if (isMin && isMax) { + rewriter.replaceOp(op, input); + return success(); + } + return failure(); + } + + if (inputElementType.isUnsignedInteger()) { + int64_t minClamp = op.min_int(); + int64_t maxClamp = op.max_int(); + + int64_t intMin = + APInt::getMinValue(inputElementType.getIntOrFloatBitWidth()) + .getZExtValue(); + int64_t intMax = + APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth()) + .getZExtValue(); + + if (minClamp <= intMin && maxClamp >= intMax) { + rewriter.replaceOp(op, input); + return success(); + } + return failure(); + } + + if (inputElementType.isa()) { + int64_t minClamp = op.min_int(); + int64_t maxClamp = op.max_int(); + + int64_t intMin = + APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth()) + .getSExtValue(); + int64_t intMax = + APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth()) + .getSExtValue(); + + if (minClamp <= intMin && maxClamp >= intMax) { + rewriter.replaceOp(op, input); + return success(); + } + return failure(); + } + + return failure(); + } +}; + +struct ClampClampOptimization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ClampOp op, + PatternRewriter &rewriter) const override { + Value input = op.input(); + + Operation *definingOp = input.getDefiningOp(); + if (!definingOp) + return failure(); + + if (tosa::ClampOp clampOp = dyn_cast(definingOp)) { + auto minFp = std::max(op.min_fp(), clampOp.min_fp()).convertToFloat(); + auto maxFp = std::min(op.max_fp(), clampOp.max_fp()).convertToFloat(); + + auto minInt = std::max(op.min_int(), clampOp.min_int()); + auto maxInt = std::min(op.max_int(), clampOp.max_int()); + + rewriter.replaceOpWithNewOp( + op, op.getType(), clampOp.input(), rewriter.getI64IntegerAttr(minInt), + rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp), + rewriter.getF32FloatAttr(maxFp)); + return success(); + } + + return failure(); + } +}; + +void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); + results.add(context); +} + +//===----------------------------------------------------------------------===// +// Operator Folders. +//===----------------------------------------------------------------------===// + +OpFoldResult CastOp::fold(ArrayRef operands) { + if (input().getType() == getType()) + return input(); + return {}; +} + +OpFoldResult ConstOp::fold(ArrayRef operands) { + assert(operands.empty() && "constant has no operands"); + return valueAttr(); +} + +#define REDUCE_FOLDER(OP) \ + OpFoldResult OP::fold(ArrayRef operands) { \ + ShapedType inputTy = input().getType().cast(); \ + if (!inputTy.hasRank()) \ + return {}; \ + if (inputTy.getDimSize(axis()) == 1) \ + return input(); \ + return {}; \ + } + +REDUCE_FOLDER(ReduceAllOp) +REDUCE_FOLDER(ReduceAnyOp) +REDUCE_FOLDER(ReduceMaxOp) +REDUCE_FOLDER(ReduceMinOp) +REDUCE_FOLDER(ReduceProdOp) +REDUCE_FOLDER(ReduceSumOp) +#undef REDUCE_FOLDER + +OpFoldResult ReshapeOp::fold(ArrayRef operands) { + auto inputTy = input1().getType().dyn_cast(); + auto outputTy = getType().dyn_cast(); + + if (!inputTy || !outputTy || inputTy != outputTy) + return {}; + return input1(); +} + +OpFoldResult PadOp::fold(ArrayRef operands) { + // If the pad is all zeros we can fold this operation away. + if (operands[1]) { + auto densePad = operands[1].cast(); + if (densePad.isSplat() && densePad.getSplatValue().isZero()) { + return input1(); + } + } + + return {}; +} + +OpFoldResult SliceOp::fold(ArrayRef operands) { + auto inputTy = input().getType().dyn_cast(); + auto outputTy = getType().dyn_cast(); + + if (!inputTy || !outputTy || inputTy != outputTy) + return {}; + if (inputTy.hasStaticShape()) + return input(); + + return {}; +} + +OpFoldResult tosa::SelectOp::fold(ArrayRef operands) { + if (on_true() == on_false()) + return on_true(); + + auto predicate = operands[0].dyn_cast_or_null(); + if (!predicate) + return {}; + + if (!predicate.isSplat()) + return {}; + return predicate.getSplatValue().getBoolValue() ? on_true() + : on_false(); +} + +OpFoldResult TileOp::fold(ArrayRef operands) { + bool allOnes = true; + for (Attribute val : multiples().getValue()) { + allOnes = allOnes && val.cast().getValue().getSExtValue() == 1; + } + + if (allOnes && input1().getType() == getType()) + return input1(); + return {}; +} + +OpFoldResult TransposeOp::fold(ArrayRef operands) { + if (!operands[1]) + return {}; + + // Transposing splat values just means reshaping. + if (auto input = operands[0].dyn_cast_or_null()) { + if (input.isSplat()) + return input.reshape(getType().cast()); + } + + auto perms = llvm::to_vector<6>(llvm::map_range( + operands[1].cast().getValues(), + [](const APInt &val) { return val.getSExtValue(); })); + + if (llvm::equal(llvm::seq(0, perms.size()), perms) && + input1().getType() == getType()) + return input1(); + return {}; +} diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 93fddb5..38a067b 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -21,9 +21,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" -#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/TypeSwitch.h" @@ -97,533 +95,6 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value, } //===----------------------------------------------------------------------===// -// Operator Canonicalizers. -//===----------------------------------------------------------------------===// - -template -void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) { - (void)std::initializer_list{ - 0, (Args::getCanonicalizationPatterns(patterns, ctx), 0)...}; -} - -void mlir::tosa::populateTosaOpsCanonicalizationPatterns( - MLIRContext *ctx, RewritePatternSet &patterns) { - addOpsCanonicalizations< -#define GET_OP_LIST -#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" - >(ctx, patterns); -} - -struct ConcatOptimization : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::ConcatOp op, - PatternRewriter &rewriter) const override { - if (op.input1().size() != 1) - return failure(); - if (op.input1().front().getType() != op.getType()) { - rewriter - .replaceOpWithNewOp(op, op.getType(), - op.input1().front()) - .getResult(); - return success(); - } - - rewriter.replaceOp(op, op.input1().front()); - return success(); - } -}; - -void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -struct ReshapeReshapeOptimization : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::ReshapeOp op, - PatternRewriter &rewriter) const override { - Value input = op.input1(); - Operation *definingOp = input.getDefiningOp(); - if (!definingOp) - return failure(); - - if (tosa::ReshapeOp reshapeOp = dyn_cast(definingOp)) { - rewriter.replaceOpWithNewOp( - op, op.getType(), reshapeOp.input1(), op.new_shape()); - return success(); - } - - return failure(); - } -}; - -struct ReshapeConstOptimization : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::ReshapeOp op, - PatternRewriter &rewriter) const override { - Value input = op.input1(); - ArrayAttr newShape = op.new_shape(); - - // Check if input is constant - DenseElementsAttr inputAttr; - if (!matchPattern(input, m_Constant(&inputAttr))) - return failure(); - - // Check if has >1 consumer and is not splat - if (!input.hasOneUse() && !inputAttr.isSplat()) - return failure(); - - // Grab the new shape - SmallVector newShapeValues = llvm::to_vector<6>( - llvm::map_range(newShape.getValue(), [](const Attribute &val) { - return val.cast().getValue().getSExtValue(); - })); - - // Build new const op with correct output shape - ShapedType inputShape = input.getType().cast(); - DenseElementsAttr outputAttr = - inputAttr.reshape(inputShape.clone(newShapeValues)); - rewriter.replaceOpWithNewOp(op, outputAttr.getType(), - outputAttr); - return success(); - } -}; - -void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); - results.add(context); -} - -LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) { - auto notOp = op.pred().getDefiningOp(); - if (!notOp) - return failure(); - rewriter.updateRootInPlace(op, [&]() { - op.getOperation()->setOperands( - {notOp.input1(), op.on_false(), op.on_true()}); - }); - return success(); -} - -struct NoOpOptimization : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::TransposeOp op, - PatternRewriter &rewriter) const override { - auto perm = op.perms(); - - DenseIntElementsAttr permAttr; - if (!matchPattern(perm, m_Constant(&permAttr))) { - return failure(); - } - - SmallVector permValues = llvm::to_vector<6>( - llvm::map_range(permAttr.getValues(), - [](const APInt &val) { return val.getSExtValue(); })); - - for (int i = 0, s = permValues.size(); i < s; i++) { - if (i != permValues[i]) { - return failure(); - } - } - - rewriter.replaceOp(op, op.input1()); - return success(); - } -}; - -void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -struct AddZeroOptimization : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::AddOp op, - PatternRewriter &rewriter) const override { - auto input1 = op.input1(); - auto input2 = op.input2(); - - DenseElementsAttr input1Attr; - if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() && - input2.getType() == op.getType()) { - if (input1Attr.getType().getElementType().isa() && - input1Attr.getSplatValue().isZero()) { - rewriter.replaceOp(op, op.input2()); - return success(); - } - } - - DenseElementsAttr input2Attr; - if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() && - input1.getType() == op.getType()) { - if (input2Attr.getType().getElementType().isa() && - input2Attr.getSplatValue().isZero()) { - rewriter.replaceOp(op, op.input1()); - return success(); - } - } - - return failure(); - } -}; - -void AddOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -struct MulOneOptimization : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::MulOp op, - PatternRewriter &rewriter) const override { - auto input1 = op.input1(); - auto input2 = op.input2(); - - DenseElementsAttr input1Attr; - if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() && - input2.getType() == op.getType()) { - if (input1Attr.getType().getElementType().isa() && - input1Attr.getSplatValue().isExactlyValue(1)) { - rewriter.replaceOp(op, op.input2()); - return success(); - } - - if (input1Attr.getType().getElementType().isa() && - matchPattern(input1, m_One())) { - rewriter.replaceOp(op, op.input2()); - return success(); - } - } - - DenseElementsAttr input2Attr; - if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() && - input1.getType() == op.getType()) { - if (input2Attr.getType().getElementType().isa() && - input2Attr.getSplatValue().isExactlyValue(1)) { - rewriter.replaceOp(op, op.input1()); - return success(); - } - - if (input2Attr.getType().getElementType().isa() && - matchPattern(input2, m_One())) { - rewriter.replaceOp(op, op.input1()); - return success(); - } - } - - return failure(); - } -}; - -void MulOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -struct MaterializePadValue : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::PadOp op, - PatternRewriter &rewriter) const override { - if (op.pad_const()) - return failure(); - - auto input = op.input1(); - auto padding = op.padding(); - - ShapedType inputTy = input.getType().cast(); - Type elementTy = inputTy.getElementType(); - - Attribute constantAttr; - if (elementTy.isa()) { - constantAttr = rewriter.getFloatAttr(elementTy, 0.0); - } else if (elementTy.isa() && !op.quantization_info()) { - constantAttr = rewriter.getIntegerAttr(elementTy, 0); - } else if (elementTy.isa() && op.quantization_info()) { - auto value = op.quantization_info()->getInputZp(); - constantAttr = rewriter.getIntegerAttr(elementTy, value); - } - - if (!constantAttr) { - return rewriter.notifyMatchFailure( - op, - "tosa.pad to linalg lowering encountered an unknown element type"); - } - - auto denseAttr = DenseElementsAttr::get( - RankedTensorType::get({}, elementTy), constantAttr); - auto constantVal = rewriter.create( - op.getLoc(), denseAttr.getType(), denseAttr); - - rewriter.replaceOpWithNewOp( - op, op.getType(), ValueRange{input, padding, constantVal}, - op->getAttrs()); - return success(); - } -}; - -void PadOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -struct MaxPool2dIsNoOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, - PatternRewriter &rewriter) const override { - Value input = op.input(); - Value output = op.output(); - ShapedType inputType = input.getType().cast(); - ShapedType outputType = output.getType().cast(); - - if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) { - return failure(); - } - - // If the output and input shapes are 1x1, then this is a no op. - ArrayRef outputShape = outputType.getShape(); - if (outputShape[1] != 1 || outputShape[2] != 1) { - return failure(); - } - - ArrayRef inputShape = inputType.getShape(); - if (inputShape[1] != 1 || inputShape[2] != 1) { - return failure(); - } - - rewriter.replaceOp(op, input); - return success(); - } -}; - -void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -struct ClampIsNoOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::ClampOp op, - PatternRewriter &rewriter) const override { - Value input = op.input(); - auto inputType = op.input().getType().template dyn_cast(); - auto inputElementType = inputType.getElementType(); - - if (!inputType.hasStaticShape()) { - return failure(); - } - - if (inputElementType.isF32()) { - auto minClamp = op.min_fp(); - auto maxClamp = op.max_fp(); - bool isMin = (minClamp.isLargest() || minClamp.isInfinity()) && - minClamp.isNegative(); - bool isMax = (maxClamp.isLargest() || maxClamp.isInfinity()) && - !maxClamp.isNegative(); - - if (isMin && isMax) { - rewriter.replaceOp(op, input); - return success(); - } - return failure(); - } - - if (inputElementType.isUnsignedInteger()) { - int64_t minClamp = op.min_int(); - int64_t maxClamp = op.max_int(); - - int64_t intMin = - APInt::getMinValue(inputElementType.getIntOrFloatBitWidth()) - .getZExtValue(); - int64_t intMax = - APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth()) - .getZExtValue(); - - if (minClamp <= intMin && maxClamp >= intMax) { - rewriter.replaceOp(op, input); - return success(); - } - return failure(); - } - - if (inputElementType.isa()) { - int64_t minClamp = op.min_int(); - int64_t maxClamp = op.max_int(); - - int64_t intMin = - APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth()) - .getSExtValue(); - int64_t intMax = - APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth()) - .getSExtValue(); - - if (minClamp <= intMin && maxClamp >= intMax) { - rewriter.replaceOp(op, input); - return success(); - } - return failure(); - } - - return failure(); - } -}; - -struct ClampClampOptimization : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::ClampOp op, - PatternRewriter &rewriter) const override { - Value input = op.input(); - - Operation *definingOp = input.getDefiningOp(); - if (!definingOp) - return failure(); - - if (tosa::ClampOp clampOp = dyn_cast(definingOp)) { - auto minFp = std::max(op.min_fp(), clampOp.min_fp()).convertToFloat(); - auto maxFp = std::min(op.max_fp(), clampOp.max_fp()).convertToFloat(); - - auto minInt = std::max(op.min_int(), clampOp.min_int()); - auto maxInt = std::min(op.max_int(), clampOp.max_int()); - - rewriter.replaceOpWithNewOp( - op, op.getType(), clampOp.input(), rewriter.getI64IntegerAttr(minInt), - rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp), - rewriter.getF32FloatAttr(maxFp)); - return success(); - } - - return failure(); - } -}; - -void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); - results.add(context); -} - -//===----------------------------------------------------------------------===// -// Operator Folders. -//===----------------------------------------------------------------------===// - -OpFoldResult CastOp::fold(ArrayRef operands) { - if (input().getType() == getType()) - return input(); - return {}; -} - -OpFoldResult ConstOp::fold(ArrayRef operands) { - assert(operands.empty() && "constant has no operands"); - return valueAttr(); -} - -#define REDUCE_FOLDER(OP) \ - OpFoldResult OP::fold(ArrayRef operands) { \ - ShapedType inputTy = input().getType().cast(); \ - if (!inputTy.hasRank()) \ - return {}; \ - if (inputTy.getDimSize(axis()) == 1) \ - return input(); \ - return {}; \ - } - -REDUCE_FOLDER(ReduceAllOp) -REDUCE_FOLDER(ReduceAnyOp) -REDUCE_FOLDER(ReduceMaxOp) -REDUCE_FOLDER(ReduceMinOp) -REDUCE_FOLDER(ReduceProdOp) -REDUCE_FOLDER(ReduceSumOp) -#undef REDUCE_FOLDER - -OpFoldResult ReshapeOp::fold(ArrayRef operands) { - auto inputTy = input1().getType().dyn_cast(); - auto outputTy = getType().dyn_cast(); - - if (!inputTy || !outputTy || inputTy != outputTy) - return {}; - return input1(); -} - -OpFoldResult PadOp::fold(ArrayRef operands) { - // If the pad is all zeros we can fold this operation away. - if (operands[1]) { - auto densePad = operands[1].cast(); - if (densePad.isSplat() && densePad.getSplatValue().isZero()) { - return input1(); - } - } - - return {}; -} - -OpFoldResult SliceOp::fold(ArrayRef operands) { - auto inputTy = input().getType().dyn_cast(); - auto outputTy = getType().dyn_cast(); - - if (!inputTy || !outputTy || inputTy != outputTy) - return {}; - if (inputTy.hasStaticShape()) - return input(); - - return {}; -} - -OpFoldResult tosa::SelectOp::fold(ArrayRef operands) { - if (on_true() == on_false()) - return on_true(); - - auto predicate = operands[0].dyn_cast_or_null(); - if (!predicate) - return {}; - - if (!predicate.isSplat()) - return {}; - return predicate.getSplatValue().getBoolValue() ? on_true() - : on_false(); -} - -OpFoldResult TileOp::fold(ArrayRef operands) { - bool allOnes = true; - for (Attribute val : multiples().getValue()) { - allOnes = allOnes && val.cast().getValue().getSExtValue() == 1; - } - - if (allOnes && input1().getType() == getType()) - return input1(); - return {}; -} - -OpFoldResult TransposeOp::fold(ArrayRef operands) { - if (!operands[1]) - return {}; - - // Transposing splat values just means reshaping. - if (auto input = operands[0].dyn_cast_or_null()) { - if (input.isSplat()) - return input.reshape(getType().cast()); - } - - auto perms = llvm::to_vector<6>(llvm::map_range( - operands[1].cast().getValues(), - [](const APInt &val) { return val.getSExtValue(); })); - - if (llvm::equal(llvm::seq(0, perms.size()), perms) && - input1().getType() == getType()) - return input1(); - return {}; -} - -//===----------------------------------------------------------------------===// // TOSA Operator Verifiers. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp index 7cf7ff1..7814b91 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp @@ -21,6 +21,20 @@ using namespace mlir::tosa; namespace { +template +void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) { + (void)std::initializer_list{ + 0, (Args::getCanonicalizationPatterns(patterns, ctx), 0)...}; +} + +void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx, + RewritePatternSet &patterns) { + addOpsCanonicalizations< +#define GET_OP_LIST +#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" + >(ctx, patterns); +} + struct TosaLayerwiseConstantFoldPass : public TosaLayerwiseConstantFoldPassBase { void runOnOperation() override { @@ -29,7 +43,7 @@ struct TosaLayerwiseConstantFoldPass auto func = getOperation(); mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns); - mlir::tosa::populateTosaOpsCanonicalizationPatterns(ctx, patterns); + populateTosaOpsCanonicalizationPatterns(ctx, patterns); if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed()) signalPassFailure(); -- 2.7.4