#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/TypeSwitch.h"
#include <cassert>
#include <cstdint>
enum class ExtensionKind { Sign, Zero };
-ExtensionKind getExtensionKind(Operation *op) {
- assert(op);
- assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
- return isa<arith::ExtSIOp>(op) ? ExtensionKind::Sign : ExtensionKind::Zero;
-}
+/// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts away
+/// the exact op type. Exposes helper functions to query the types, operands,
+/// and the result. This is so that we can handle both extension kinds without
+/// needing to use templates or branching.
+class ExtensionOp {
+public:
+ /// Attemps to create a new extension op from `op`. Returns an extension op
+ /// wrapper when `op` is either `arith.extsi` or `arith.extui`, and failure
+ /// otherwise.
+ static FailureOr<ExtensionOp> from(Operation *op) {
+ if (auto sext = dyn_cast_or_null<arith::ExtSIOp>(op))
+ return ExtensionOp{op, ExtensionKind::Sign};
+ if (auto zext = dyn_cast_or_null<arith::ExtUIOp>(op))
+ return ExtensionOp{op, ExtensionKind::Zero};
+
+ return failure();
+ }
+
+ ExtensionOp(const ExtensionOp &) = default;
+ ExtensionOp &operator=(const ExtensionOp &) = default;
+
+ /// Creates a new extension op of the same kind.
+ Operation *recreate(PatternRewriter &rewriter, Location loc, Type newType,
+ Value in) {
+ if (kind == ExtensionKind::Sign)
+ return rewriter.create<arith::ExtSIOp>(loc, newType, in);
+
+ return rewriter.create<arith::ExtUIOp>(loc, newType, in);
+ }
+
+ /// Replaces `toReplace` with a new extension op of the same kind.
+ void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace,
+ Value in) {
+ assert(toReplace->getNumResults() == 1);
+ Type newType = toReplace->getResult(0).getType();
+ Operation *newOp = recreate(rewriter, toReplace->getLoc(), newType, in);
+ rewriter.replaceOp(toReplace, newOp->getResult(0));
+ }
+
+ ExtensionKind getKind() { return kind; }
+
+ Value getResult() { return op->getResult(0); }
+ Value getIn() { return op->getOperand(0); }
+
+ Type getType() { return getResult().getType(); }
+ Type getElementType() { return getElementTypeOrSelf(getType()); }
+ Type getInType() { return getIn().getType(); }
+ Type getInElementType() { return getElementTypeOrSelf(getInType()); }
+
+private:
+ ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) {
+ assert(op);
+ assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
+ }
+ Operation *op = nullptr;
+ ExtensionKind kind = {};
+};
/// Returns the integer bitwidth required to represent `value`.
unsigned calculateBitsRequired(const APInt &value,
LogicalResult matchAndRewrite(vector::ExtractOp op,
PatternRewriter &rewriter) const override {
- Operation *def = op.getVector().getDefiningOp();
- if (!def)
+ FailureOr<ExtensionOp> ext =
+ ExtensionOp::from(op.getVector().getDefiningOp());
+ if (failed(ext))
return failure();
- return TypeSwitch<Operation *, LogicalResult>(def)
- .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
- Value newExtract = rewriter.create<vector::ExtractOp>(
- op.getLoc(), extOp.getIn(), op.getPosition());
- rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
- newExtract);
- return success();
- })
- .Default(failure());
+ Value newExtract = rewriter.create<vector::ExtractOp>(
+ op.getLoc(), ext->getIn(), op.getPosition());
+ ext->recreateAndReplace(rewriter, op, newExtract);
+ return success();
}
};
LogicalResult matchAndRewrite(vector::ExtractElementOp op,
PatternRewriter &rewriter) const override {
- Operation *def = op.getVector().getDefiningOp();
- if (!def)
+ FailureOr<ExtensionOp> ext =
+ ExtensionOp::from(op.getVector().getDefiningOp());
+ if (failed(ext))
return failure();
- return TypeSwitch<Operation *, LogicalResult>(def)
- .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
- Value newExtract = rewriter.create<vector::ExtractElementOp>(
- op.getLoc(), extOp.getIn(), op.getPosition());
- rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
- newExtract);
- return success();
- })
- .Default(failure());
+ Value newExtract = rewriter.create<vector::ExtractElementOp>(
+ op.getLoc(), ext->getIn(), op.getPosition());
+ ext->recreateAndReplace(rewriter, op, newExtract);
+ return success();
}
};
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
- Operation *def = op.getVector().getDefiningOp();
- if (!def)
+ FailureOr<ExtensionOp> ext =
+ ExtensionOp::from(op.getVector().getDefiningOp());
+ if (failed(ext))
return failure();
- return TypeSwitch<Operation *, LogicalResult>(def)
- .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
- VectorType origTy = op.getType();
- Type inElemTy =
- cast<VectorType>(extOp.getIn().getType()).getElementType();
- VectorType extractTy = origTy.cloneWith(origTy.getShape(), inElemTy);
- Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
- op.getLoc(), extractTy, extOp.getIn(), op.getOffsets(),
- op.getSizes(), op.getStrides());
- rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
- newExtract);
- return success();
- })
- .Default(failure());
+ VectorType origTy = op.getType();
+ VectorType extractTy =
+ origTy.cloneWith(origTy.getShape(), ext->getInElementType());
+ Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
+ op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
+ op.getStrides());
+ ext->recreateAndReplace(rewriter, op, newExtract);
+ return success();
}
};
LogicalResult matchAndRewrite(vector::InsertOp op,
PatternRewriter &rewriter) const override {
- Operation *def = op.getSource().getDefiningOp();
- if (!def)
+ FailureOr<ExtensionOp> ext =
+ ExtensionOp::from(op.getSource().getDefiningOp());
+ if (failed(ext))
return failure();
- return TypeSwitch<Operation *, LogicalResult>(def)
- .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
- // Rewrite the insertion in terms of narrower operands
- // and later extend the result to the original bitwidth.
- FailureOr<vector::InsertOp> newInsert =
- createNarrowInsert(op, rewriter, extOp);
- if (failed(newInsert))
- return failure();
- rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
- *newInsert);
- return success();
- })
- .Default(failure());
+ FailureOr<vector::InsertOp> newInsert =
+ createNarrowInsert(op, rewriter, *ext);
+ if (failed(newInsert))
+ return failure();
+ ext->recreateAndReplace(rewriter, op, *newInsert);
+ return success();
}
FailureOr<vector::InsertOp> createNarrowInsert(vector::InsertOp op,
PatternRewriter &rewriter,
- Operation *insValue) const {
- assert((isa<arith::ExtSIOp, arith::ExtUIOp>(insValue)));
-
+ ExtensionOp insValue) const {
// Calculate the operand and result bitwidths. We can only apply narrowing
// when the inserted source value and destination vector require fewer bits
// than the result. Because the source and destination may have different
if (failed(origBitsRequired))
return failure();
- ExtensionKind kind = getExtensionKind(insValue);
FailureOr<unsigned> destBitsRequired =
- calculateBitsRequired(op.getDest(), kind);
+ calculateBitsRequired(op.getDest(), insValue.getKind());
if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
return failure();
FailureOr<unsigned> insertedBitsRequired =
- calculateBitsRequired(insValue->getOperands().front(), kind);
+ calculateBitsRequired(insValue.getIn(), insValue.getKind());
if (failed(insertedBitsRequired) ||
*insertedBitsRequired >= *origBitsRequired)
return failure();
return failure();
FailureOr<Type> newInsertedValueTy =
- getNarrowType(newInsertionBits, insValue->getResultTypes().front());
+ getNarrowType(newInsertionBits, insValue.getType());
if (failed(newInsertedValueTy))
return failure();
Location loc = op.getLoc();
Value narrowValue = rewriter.createOrFold<arith::TruncIOp>(
- loc, *newInsertedValueTy, insValue->getResult(0));
+ loc, *newInsertedValueTy, insValue.getResult());
Value narrowDest =
rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
return rewriter.create<vector::InsertOp>(loc, narrowValue, narrowDest,