[mlir][arith][NFC] Simplify narrowing patterns with a wrapper type
authorJakub Kuderski <kubak@google.com>
Mon, 1 May 2023 17:31:30 +0000 (13:31 -0400)
committerJakub Kuderski <kubak@google.com>
Mon, 1 May 2023 17:31:31 +0000 (13:31 -0400)
Add a new wraper type that represents either of `ExtSIOp` or `ExtUIOp`.
This is to simplify the code by using a single type, so that we do not
have to use templates or branching to handle both extension kinds.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149485

mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp

index 639b19b..c515824 100644 (file)
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/TypeSwitch.h"
 #include <cassert>
 #include <cstdint>
 
@@ -100,11 +100,63 @@ FailureOr<unsigned> calculateBitsRequired(Type type) {
 
 enum class ExtensionKind { Sign, Zero };
 
-ExtensionKind getExtensionKind(Operation *op) {
-  assert(op);
-  assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
-  return isa<arith::ExtSIOp>(op) ? ExtensionKind::Sign : ExtensionKind::Zero;
-}
+/// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts away
+/// the exact op type. Exposes helper functions to query the types, operands,
+/// and the result. This is so that we can handle both extension kinds without
+/// needing to use templates or branching.
+class ExtensionOp {
+public:
+  /// Attemps to create a new extension op from `op`. Returns an extension op
+  /// wrapper when `op` is either `arith.extsi` or `arith.extui`, and failure
+  /// otherwise.
+  static FailureOr<ExtensionOp> from(Operation *op) {
+    if (auto sext = dyn_cast_or_null<arith::ExtSIOp>(op))
+      return ExtensionOp{op, ExtensionKind::Sign};
+    if (auto zext = dyn_cast_or_null<arith::ExtUIOp>(op))
+      return ExtensionOp{op, ExtensionKind::Zero};
+
+    return failure();
+  }
+
+  ExtensionOp(const ExtensionOp &) = default;
+  ExtensionOp &operator=(const ExtensionOp &) = default;
+
+  /// Creates a new extension op of the same kind.
+  Operation *recreate(PatternRewriter &rewriter, Location loc, Type newType,
+                      Value in) {
+    if (kind == ExtensionKind::Sign)
+      return rewriter.create<arith::ExtSIOp>(loc, newType, in);
+
+    return rewriter.create<arith::ExtUIOp>(loc, newType, in);
+  }
+
+  /// Replaces `toReplace` with a new extension op of the same kind.
+  void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace,
+                          Value in) {
+    assert(toReplace->getNumResults() == 1);
+    Type newType = toReplace->getResult(0).getType();
+    Operation *newOp = recreate(rewriter, toReplace->getLoc(), newType, in);
+    rewriter.replaceOp(toReplace, newOp->getResult(0));
+  }
+
+  ExtensionKind getKind() { return kind; }
+
+  Value getResult() { return op->getResult(0); }
+  Value getIn() { return op->getOperand(0); }
+
+  Type getType() { return getResult().getType(); }
+  Type getElementType() { return getElementTypeOrSelf(getType()); }
+  Type getInType() { return getIn().getType(); }
+  Type getInElementType() { return getElementTypeOrSelf(getInType()); }
+
+private:
+  ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) {
+    assert(op);
+    assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
+  }
+  Operation *op = nullptr;
+  ExtensionKind kind = {};
+};
 
 /// Returns the integer bitwidth required to represent `value`.
 unsigned calculateBitsRequired(const APInt &value,
@@ -202,19 +254,15 @@ struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
 
   LogicalResult matchAndRewrite(vector::ExtractOp op,
                                 PatternRewriter &rewriter) const override {
-    Operation *def = op.getVector().getDefiningOp();
-    if (!def)
+    FailureOr<ExtensionOp> ext =
+        ExtensionOp::from(op.getVector().getDefiningOp());
+    if (failed(ext))
       return failure();
 
-    return TypeSwitch<Operation *, LogicalResult>(def)
-        .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
-          Value newExtract = rewriter.create<vector::ExtractOp>(
-              op.getLoc(), extOp.getIn(), op.getPosition());
-          rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
-                                                       newExtract);
-          return success();
-        })
-        .Default(failure());
+    Value newExtract = rewriter.create<vector::ExtractOp>(
+        op.getLoc(), ext->getIn(), op.getPosition());
+    ext->recreateAndReplace(rewriter, op, newExtract);
+    return success();
   }
 };
 
@@ -224,19 +272,15 @@ struct ExtensionOverExtractElement final
 
   LogicalResult matchAndRewrite(vector::ExtractElementOp op,
                                 PatternRewriter &rewriter) const override {
-    Operation *def = op.getVector().getDefiningOp();
-    if (!def)
+    FailureOr<ExtensionOp> ext =
+        ExtensionOp::from(op.getVector().getDefiningOp());
+    if (failed(ext))
       return failure();
 
-    return TypeSwitch<Operation *, LogicalResult>(def)
-        .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
-          Value newExtract = rewriter.create<vector::ExtractElementOp>(
-              op.getLoc(), extOp.getIn(), op.getPosition());
-          rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
-                                                       newExtract);
-          return success();
-        })
-        .Default(failure());
+    Value newExtract = rewriter.create<vector::ExtractElementOp>(
+        op.getLoc(), ext->getIn(), op.getPosition());
+    ext->recreateAndReplace(rewriter, op, newExtract);
+    return success();
   }
 };
 
@@ -246,24 +290,19 @@ struct ExtensionOverExtractStridedSlice final
 
   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
                                 PatternRewriter &rewriter) const override {
-    Operation *def = op.getVector().getDefiningOp();
-    if (!def)
+    FailureOr<ExtensionOp> ext =
+        ExtensionOp::from(op.getVector().getDefiningOp());
+    if (failed(ext))
       return failure();
 
-    return TypeSwitch<Operation *, LogicalResult>(def)
-        .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
-          VectorType origTy = op.getType();
-          Type inElemTy =
-              cast<VectorType>(extOp.getIn().getType()).getElementType();
-          VectorType extractTy = origTy.cloneWith(origTy.getShape(), inElemTy);
-          Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
-              op.getLoc(), extractTy, extOp.getIn(), op.getOffsets(),
-              op.getSizes(), op.getStrides());
-          rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
-                                                       newExtract);
-          return success();
-        })
-        .Default(failure());
+    VectorType origTy = op.getType();
+    VectorType extractTy =
+        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
+    Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
+        op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
+        op.getStrides());
+    ext->recreateAndReplace(rewriter, op, newExtract);
+    return success();
   }
 };
 
@@ -272,30 +311,22 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
 
   LogicalResult matchAndRewrite(vector::InsertOp op,
                                 PatternRewriter &rewriter) const override {
-    Operation *def = op.getSource().getDefiningOp();
-    if (!def)
+    FailureOr<ExtensionOp> ext =
+        ExtensionOp::from(op.getSource().getDefiningOp());
+    if (failed(ext))
       return failure();
 
-    return TypeSwitch<Operation *, LogicalResult>(def)
-        .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
-          // Rewrite the insertion in terms of narrower operands
-          // and later extend the result to the original bitwidth.
-          FailureOr<vector::InsertOp> newInsert =
-              createNarrowInsert(op, rewriter, extOp);
-          if (failed(newInsert))
-            return failure();
-          rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
-                                                       *newInsert);
-          return success();
-        })
-        .Default(failure());
+    FailureOr<vector::InsertOp> newInsert =
+        createNarrowInsert(op, rewriter, *ext);
+    if (failed(newInsert))
+      return failure();
+    ext->recreateAndReplace(rewriter, op, *newInsert);
+    return success();
   }
 
   FailureOr<vector::InsertOp> createNarrowInsert(vector::InsertOp op,
                                                  PatternRewriter &rewriter,
-                                                 Operation *insValue) const {
-    assert((isa<arith::ExtSIOp, arith::ExtUIOp>(insValue)));
-
+                                                 ExtensionOp insValue) const {
     // Calculate the operand and result bitwidths. We can only apply narrowing
     // when the inserted source value and destination vector require fewer bits
     // than the result. Because the source and destination may have different
@@ -306,14 +337,13 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
     if (failed(origBitsRequired))
       return failure();
 
-    ExtensionKind kind = getExtensionKind(insValue);
     FailureOr<unsigned> destBitsRequired =
-        calculateBitsRequired(op.getDest(), kind);
+        calculateBitsRequired(op.getDest(), insValue.getKind());
     if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
       return failure();
 
     FailureOr<unsigned> insertedBitsRequired =
-        calculateBitsRequired(insValue->getOperands().front(), kind);
+        calculateBitsRequired(insValue.getIn(), insValue.getKind());
     if (failed(insertedBitsRequired) ||
         *insertedBitsRequired >= *origBitsRequired)
       return failure();
@@ -327,13 +357,13 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
       return failure();
 
     FailureOr<Type> newInsertedValueTy =
-        getNarrowType(newInsertionBits, insValue->getResultTypes().front());
+        getNarrowType(newInsertionBits, insValue.getType());
     if (failed(newInsertedValueTy))
       return failure();
 
     Location loc = op.getLoc();
     Value narrowValue = rewriter.createOrFold<arith::TruncIOp>(
-        loc, *newInsertedValueTy, insValue->getResult(0));
+        loc, *newInsertedValueTy, insValue.getResult());
     Value narrowDest =
         rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
     return rewriter.create<vector::InsertOp>(loc, narrowValue, narrowDest,