[mlir][arith] Add patterns to commute extension over vector extraction
authorJakub Kuderski <kubak@google.com>
Fri, 28 Apr 2023 17:48:48 +0000 (13:48 -0400)
committerJakub Kuderski <kubak@google.com>
Fri, 28 Apr 2023 17:48:50 +0000 (13:48 -0400)
This moves zero/sign-extension ops closer to their use and exposes more
narrowing optimization opportunities.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149233

mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
mlir/test/Dialect/Arith/int-narrowing.mlir

index e6fe468..50b7484 100644 (file)
@@ -90,6 +90,7 @@ def ArithIntNarrowing : Pass<"arith-int-narrowing"> {
     prefers the narrowest available integer bitwidths that are guaranteed to
     produce the same results.
   }];
+  let dependentDialects = ["vector::VectorDialect"];
   let options = [
     ListOption<"bitwidthsSupported", "int-bitwidths-supported", "unsigned",
                "Integer bitwidths supported">,
index e884a19..3401a9c 100644 (file)
@@ -9,16 +9,19 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include <cassert>
 #include <cstdint>
 
@@ -144,6 +147,80 @@ using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
 using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
 
 //===----------------------------------------------------------------------===//
+// Patterns to Commute Extension Ops
+//===----------------------------------------------------------------------===//
+
+struct ExtensionOverExtract final : OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *def = op.getVector().getDefiningOp();
+    if (!def)
+      return failure();
+
+    return TypeSwitch<Operation *, LogicalResult>(def)
+        .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
+          Value newExtract = rewriter.create<vector::ExtractOp>(
+              op.getLoc(), extOp.getIn(), op.getPosition());
+          rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
+                                                       newExtract);
+          return success();
+        })
+        .Default(failure());
+  }
+};
+
+struct ExtensionOverExtractElement final
+    : OpRewritePattern<vector::ExtractElementOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractElementOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *def = op.getVector().getDefiningOp();
+    if (!def)
+      return failure();
+
+    return TypeSwitch<Operation *, LogicalResult>(def)
+        .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
+          Value newExtract = rewriter.create<vector::ExtractElementOp>(
+              op.getLoc(), extOp.getIn(), op.getPosition());
+          rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
+                                                       newExtract);
+          return success();
+        })
+        .Default(failure());
+  }
+};
+
+struct ExtensionOverExtractStridedSlice final
+    : OpRewritePattern<vector::ExtractStridedSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *def = op.getVector().getDefiningOp();
+    if (!def)
+      return failure();
+
+    return TypeSwitch<Operation *, LogicalResult>(def)
+        .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
+          VectorType origTy = op.getType();
+          Type inElemTy =
+              cast<VectorType>(extOp.getIn().getType()).getElementType();
+          VectorType extractTy = origTy.cloneWith(origTy.getShape(), inElemTy);
+          Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
+              op.getLoc(), extractTy, extOp.getIn(), op.getOffsets(),
+              op.getSizes(), op.getStrides());
+          rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
+                                                       newExtract);
+          return success();
+        })
+        .Default(failure());
+  }
+};
+
+//===----------------------------------------------------------------------===//
 // Pass Definitions
 //===----------------------------------------------------------------------===//
 
@@ -169,6 +246,12 @@ struct ArithIntNarrowingPass final
 
 void populateArithIntNarrowingPatterns(
     RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
+  // Add commute patterns with a higher benefit. This is to expose more
+  // optimization opportunities to narrowing patterns.
+  patterns.add<ExtensionOverExtract, ExtensionOverExtractElement,
+               ExtensionOverExtractStridedSlice>(patterns.getContext(),
+                                                 PatternBenefit(2));
+
   patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);
 }
 
index 21d5ab7..f1290e5 100644 (file)
@@ -1,6 +1,10 @@
 // RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" \
 // RUN:          --verify-diagnostics %s | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// arith.*itofp
+//===----------------------------------------------------------------------===//
+
 // CHECK-LABEL: func.func @sitofp_extsi_i16
 // CHECK-SAME:    (%[[ARG:.+]]: i16)
 // CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[ARG]] : i16 to f16
@@ -131,3 +135,103 @@ func.func @uitofp_extsi_i16(%a: i16) -> f16 {
   %f = arith.uitofp %b : i32 to f16
   return %f : f16
 }
+
+//===----------------------------------------------------------------------===//
+// Commute Extension over Vector Ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @extsi_over_extract_3xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>)
+// CHECK-NEXT:    %[[EXTR:.+]] = vector.extract %[[ARG]][1] : vector<3xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.sitofp %[[EXTR]] : i16 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @extsi_over_extract_3xi16(%a: vector<3xi16>) -> f16 {
+  %b = arith.extsi %a : vector<3xi16> to vector<3xi32>
+  %c = vector.extract %b[1] : vector<3xi32>
+  %f = arith.sitofp %c : i32 to f16
+  return %f : f16
+}
+
+// CHECK-LABEL: func.func @extui_over_extract_3xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>)
+// CHECK-NEXT:    %[[EXTR:.+]] = vector.extract %[[ARG]][1] : vector<3xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.uitofp %[[EXTR]] : i16 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @extui_over_extract_3xi16(%a: vector<3xi16>) -> f16 {
+  %b = arith.extui %a : vector<3xi16> to vector<3xi32>
+  %c = vector.extract %b[1] : vector<3xi32>
+  %f = arith.uitofp %c : i32 to f16
+  return %f : f16
+}
+
+// CHECK-LABEL: func.func @extsi_over_extractelement_3xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>, %[[POS:.+]]: i32)
+// CHECK-NEXT:    %[[EXTR:.+]] = vector.extractelement %[[ARG]][%[[POS]] : i32] : vector<3xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.sitofp %[[EXTR]] : i16 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @extsi_over_extractelement_3xi16(%a: vector<3xi16>, %pos: i32) -> f16 {
+  %b = arith.extsi %a : vector<3xi16> to vector<3xi32>
+  %c = vector.extractelement %b[%pos : i32] : vector<3xi32>
+  %f = arith.sitofp %c : i32 to f16
+  return %f : f16
+}
+
+// CHECK-LABEL: func.func @extui_over_extractelement_3xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>, %[[POS:.+]]: i32)
+// CHECK-NEXT:    %[[EXTR:.+]] = vector.extractelement %[[ARG]][%[[POS]] : i32] : vector<3xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.uitofp %[[EXTR]] : i16 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @extui_over_extractelement_3xi16(%a: vector<3xi16>, %pos: i32) -> f16 {
+  %b = arith.extui %a : vector<3xi16> to vector<3xi32>
+  %c = vector.extractelement %b[%pos : i32] : vector<3xi32>
+  %f = arith.uitofp %c : i32 to f16
+  return %f : f16
+}
+
+// CHECK-LABEL: func.func @extsi_over_extract_strided_slice_1d
+// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>)
+// CHECK-NEXT:    %[[EXTR:.+]] = vector.extract_strided_slice %[[ARG]] {offsets = [1], sizes = [2], strides = [1]} : vector<3xi16> to vector<2xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[EXTR]] : vector<2xi16> to vector<2xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<2xi32>
+func.func @extsi_over_extract_strided_slice_1d(%a: vector<3xi16>) -> vector<2xi32> {
+  %b = arith.extsi %a : vector<3xi16> to vector<3xi32>
+  %c = vector.extract_strided_slice %b
+   {offsets = [1], sizes = [2], strides = [1]} : vector<3xi32> to vector<2xi32>
+  return %c : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_extract_strided_slice_1d
+// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>)
+// CHECK-NEXT:    %[[EXTR:.+]] = vector.extract_strided_slice %[[ARG]] {offsets = [1], sizes = [2], strides = [1]} : vector<3xi16> to vector<2xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[EXTR]] : vector<2xi16> to vector<2xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<2xi32>
+func.func @extui_over_extract_strided_slice_1d(%a: vector<3xi16>) -> vector<2xi32> {
+  %b = arith.extui %a : vector<3xi16> to vector<3xi32>
+  %c = vector.extract_strided_slice %b
+   {offsets = [1], sizes = [2], strides = [1]} : vector<3xi32> to vector<2xi32>
+  return %c : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_extract_strided_slice_2d
+// CHECK-SAME:    (%[[ARG:.+]]: vector<2x3xi16>)
+// CHECK-NEXT:    %[[EXTR:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi16> to vector<1x2xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[EXTR]] : vector<1x2xi16> to vector<1x2xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<1x2xi32>
+func.func @extsi_over_extract_strided_slice_2d(%a: vector<2x3xi16>) -> vector<1x2xi32> {
+  %b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32>
+  %c = vector.extract_strided_slice %b
+   {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32>
+  return %c : vector<1x2xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_extract_strided_slice_2d
+// CHECK-SAME:    (%[[ARG:.+]]: vector<2x3xi16>)
+// CHECK-NEXT:    %[[EXTR:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi16> to vector<1x2xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[EXTR]] : vector<1x2xi16> to vector<1x2xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<1x2xi32>
+func.func @extui_over_extract_strided_slice_2d(%a: vector<2x3xi16>) -> vector<1x2xi32> {
+  %b = arith.extui %a : vector<2x3xi16> to vector<2x3xi32>
+  %c = vector.extract_strided_slice %b
+   {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32>
+  return %c : vector<1x2xi32>
+}