From 49df068836bfbb538771395d8bb293548afd414e Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 1 May 2023 13:31:30 -0400 Subject: [PATCH] [mlir][arith][NFC] Simplify narrowing patterns with a wrapper type Add a new wraper type that represents either of `ExtSIOp` or `ExtUIOp`. This is to simplify the code by using a single type, so that we do not have to use templates or branching to handle both extension kinds. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D149485 --- mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp | 164 ++++++++++++--------- 1 file changed, 97 insertions(+), 67 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp index 639b19b..c515824 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp @@ -15,13 +15,13 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/TypeSwitch.h" #include #include @@ -100,11 +100,63 @@ FailureOr calculateBitsRequired(Type type) { enum class ExtensionKind { Sign, Zero }; -ExtensionKind getExtensionKind(Operation *op) { - assert(op); - assert((isa(op)) && "Not an extension op"); - return isa(op) ? ExtensionKind::Sign : ExtensionKind::Zero; -} +/// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts away +/// the exact op type. Exposes helper functions to query the types, operands, +/// and the result. This is so that we can handle both extension kinds without +/// needing to use templates or branching. +class ExtensionOp { +public: + /// Attemps to create a new extension op from `op`. Returns an extension op + /// wrapper when `op` is either `arith.extsi` or `arith.extui`, and failure + /// otherwise. + static FailureOr from(Operation *op) { + if (auto sext = dyn_cast_or_null(op)) + return ExtensionOp{op, ExtensionKind::Sign}; + if (auto zext = dyn_cast_or_null(op)) + return ExtensionOp{op, ExtensionKind::Zero}; + + return failure(); + } + + ExtensionOp(const ExtensionOp &) = default; + ExtensionOp &operator=(const ExtensionOp &) = default; + + /// Creates a new extension op of the same kind. + Operation *recreate(PatternRewriter &rewriter, Location loc, Type newType, + Value in) { + if (kind == ExtensionKind::Sign) + return rewriter.create(loc, newType, in); + + return rewriter.create(loc, newType, in); + } + + /// Replaces `toReplace` with a new extension op of the same kind. + void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace, + Value in) { + assert(toReplace->getNumResults() == 1); + Type newType = toReplace->getResult(0).getType(); + Operation *newOp = recreate(rewriter, toReplace->getLoc(), newType, in); + rewriter.replaceOp(toReplace, newOp->getResult(0)); + } + + ExtensionKind getKind() { return kind; } + + Value getResult() { return op->getResult(0); } + Value getIn() { return op->getOperand(0); } + + Type getType() { return getResult().getType(); } + Type getElementType() { return getElementTypeOrSelf(getType()); } + Type getInType() { return getIn().getType(); } + Type getInElementType() { return getElementTypeOrSelf(getInType()); } + +private: + ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) { + assert(op); + assert((isa(op)) && "Not an extension op"); + } + Operation *op = nullptr; + ExtensionKind kind = {}; +}; /// Returns the integer bitwidth required to represent `value`. unsigned calculateBitsRequired(const APInt &value, @@ -202,19 +254,15 @@ struct ExtensionOverExtract final : NarrowingPattern { LogicalResult matchAndRewrite(vector::ExtractOp op, PatternRewriter &rewriter) const override { - Operation *def = op.getVector().getDefiningOp(); - if (!def) + FailureOr ext = + ExtensionOp::from(op.getVector().getDefiningOp()); + if (failed(ext)) return failure(); - return TypeSwitch(def) - .Case([&](auto extOp) { - Value newExtract = rewriter.create( - op.getLoc(), extOp.getIn(), op.getPosition()); - rewriter.replaceOpWithNewOp(op, op.getType(), - newExtract); - return success(); - }) - .Default(failure()); + Value newExtract = rewriter.create( + op.getLoc(), ext->getIn(), op.getPosition()); + ext->recreateAndReplace(rewriter, op, newExtract); + return success(); } }; @@ -224,19 +272,15 @@ struct ExtensionOverExtractElement final LogicalResult matchAndRewrite(vector::ExtractElementOp op, PatternRewriter &rewriter) const override { - Operation *def = op.getVector().getDefiningOp(); - if (!def) + FailureOr ext = + ExtensionOp::from(op.getVector().getDefiningOp()); + if (failed(ext)) return failure(); - return TypeSwitch(def) - .Case([&](auto extOp) { - Value newExtract = rewriter.create( - op.getLoc(), extOp.getIn(), op.getPosition()); - rewriter.replaceOpWithNewOp(op, op.getType(), - newExtract); - return success(); - }) - .Default(failure()); + Value newExtract = rewriter.create( + op.getLoc(), ext->getIn(), op.getPosition()); + ext->recreateAndReplace(rewriter, op, newExtract); + return success(); } }; @@ -246,24 +290,19 @@ struct ExtensionOverExtractStridedSlice final LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { - Operation *def = op.getVector().getDefiningOp(); - if (!def) + FailureOr ext = + ExtensionOp::from(op.getVector().getDefiningOp()); + if (failed(ext)) return failure(); - return TypeSwitch(def) - .Case([&](auto extOp) { - VectorType origTy = op.getType(); - Type inElemTy = - cast(extOp.getIn().getType()).getElementType(); - VectorType extractTy = origTy.cloneWith(origTy.getShape(), inElemTy); - Value newExtract = rewriter.create( - op.getLoc(), extractTy, extOp.getIn(), op.getOffsets(), - op.getSizes(), op.getStrides()); - rewriter.replaceOpWithNewOp(op, op.getType(), - newExtract); - return success(); - }) - .Default(failure()); + VectorType origTy = op.getType(); + VectorType extractTy = + origTy.cloneWith(origTy.getShape(), ext->getInElementType()); + Value newExtract = rewriter.create( + op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(), + op.getStrides()); + ext->recreateAndReplace(rewriter, op, newExtract); + return success(); } }; @@ -272,30 +311,22 @@ struct ExtensionOverInsert final : NarrowingPattern { LogicalResult matchAndRewrite(vector::InsertOp op, PatternRewriter &rewriter) const override { - Operation *def = op.getSource().getDefiningOp(); - if (!def) + FailureOr ext = + ExtensionOp::from(op.getSource().getDefiningOp()); + if (failed(ext)) return failure(); - return TypeSwitch(def) - .Case([&](auto extOp) { - // Rewrite the insertion in terms of narrower operands - // and later extend the result to the original bitwidth. - FailureOr newInsert = - createNarrowInsert(op, rewriter, extOp); - if (failed(newInsert)) - return failure(); - rewriter.replaceOpWithNewOp(op, op.getType(), - *newInsert); - return success(); - }) - .Default(failure()); + FailureOr newInsert = + createNarrowInsert(op, rewriter, *ext); + if (failed(newInsert)) + return failure(); + ext->recreateAndReplace(rewriter, op, *newInsert); + return success(); } FailureOr createNarrowInsert(vector::InsertOp op, PatternRewriter &rewriter, - Operation *insValue) const { - assert((isa(insValue))); - + ExtensionOp insValue) const { // Calculate the operand and result bitwidths. We can only apply narrowing // when the inserted source value and destination vector require fewer bits // than the result. Because the source and destination may have different @@ -306,14 +337,13 @@ struct ExtensionOverInsert final : NarrowingPattern { if (failed(origBitsRequired)) return failure(); - ExtensionKind kind = getExtensionKind(insValue); FailureOr destBitsRequired = - calculateBitsRequired(op.getDest(), kind); + calculateBitsRequired(op.getDest(), insValue.getKind()); if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired) return failure(); FailureOr insertedBitsRequired = - calculateBitsRequired(insValue->getOperands().front(), kind); + calculateBitsRequired(insValue.getIn(), insValue.getKind()); if (failed(insertedBitsRequired) || *insertedBitsRequired >= *origBitsRequired) return failure(); @@ -327,13 +357,13 @@ struct ExtensionOverInsert final : NarrowingPattern { return failure(); FailureOr newInsertedValueTy = - getNarrowType(newInsertionBits, insValue->getResultTypes().front()); + getNarrowType(newInsertionBits, insValue.getType()); if (failed(newInsertedValueTy)) return failure(); Location loc = op.getLoc(); Value narrowValue = rewriter.createOrFold( - loc, *newInsertedValueTy, insValue->getResult(0)); + loc, *newInsertedValueTy, insValue.getResult()); Value narrowDest = rewriter.createOrFold(loc, *newVecTy, op.getDest()); return rewriter.create(loc, narrowValue, narrowDest, -- 2.7.4