From 09043a26c85dad1ae33a32c0927d467f622de157 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 28 Apr 2023 13:48:48 -0400 Subject: [PATCH] [mlir][arith] Add patterns to commute extension over vector extraction This moves zero/sign-extension ops closer to their use and exposes more narrowing optimization opportunities. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D149233 --- .../mlir/Dialect/Arith/Transforms/Passes.td | 1 + mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp | 83 ++++++++++++++++ mlir/test/Dialect/Arith/int-narrowing.mlir | 104 +++++++++++++++++++++ 3 files changed, 188 insertions(+) diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td index e6fe468..50b7484 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td @@ -90,6 +90,7 @@ def ArithIntNarrowing : Pass<"arith-int-narrowing"> { prefers the narrowest available integer bitwidths that are guaranteed to produce the same results. 
}]; + let dependentDialects = ["vector::VectorDialect"]; let options = [ ListOption<"bitwidthsSupported", "int-bitwidths-supported", "unsigned", "Integer bitwidths supported">, diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp index e884a19..3401a9c 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp @@ -9,16 +9,19 @@ #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" #include #include @@ -144,6 +147,80 @@ using SIToFPPattern = IToFPPattern; using UIToFPPattern = IToFPPattern; //===----------------------------------------------------------------------===// +// Patterns to Commute Extension Ops +//===----------------------------------------------------------------------===// + +struct ExtensionOverExtract final : OpRewritePattern<vector::ExtractOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ExtractOp op, + PatternRewriter &rewriter) const override { + Operation *def = op.getVector().getDefiningOp(); + if (!def) + return failure(); + + return TypeSwitch<Operation *, LogicalResult>(def) + .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) { + Value newExtract = rewriter.create<vector::ExtractOp>( + op.getLoc(), extOp.getIn(), op.getPosition()); + rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(), + newExtract); + return success(); + }) + .Default(failure()); + } +}; + +struct ExtensionOverExtractElement final + : OpRewritePattern<vector::ExtractElementOp> { + using
OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ExtractElementOp op, + PatternRewriter &rewriter) const override { + Operation *def = op.getVector().getDefiningOp(); + if (!def) + return failure(); + + return TypeSwitch<Operation *, LogicalResult>(def) + .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) { + Value newExtract = rewriter.create<vector::ExtractElementOp>( + op.getLoc(), extOp.getIn(), op.getPosition()); + rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(), + newExtract); + return success(); + }) + .Default(failure()); + } +}; + +struct ExtensionOverExtractStridedSlice final + : OpRewritePattern<vector::ExtractStridedSliceOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op, + PatternRewriter &rewriter) const override { + Operation *def = op.getVector().getDefiningOp(); + if (!def) + return failure(); + + return TypeSwitch<Operation *, LogicalResult>(def) + .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) { + VectorType origTy = op.getType(); + Type inElemTy = + cast<VectorType>(extOp.getIn().getType()).getElementType(); + VectorType extractTy = origTy.cloneWith(origTy.getShape(), inElemTy); + Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>( + op.getLoc(), extractTy, extOp.getIn(), op.getOffsets(), + op.getSizes(), op.getStrides()); + rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(), + newExtract); + return success(); + }) + .Default(failure()); + } +}; + +//===----------------------------------------------------------------------===// // Pass Definitions //===----------------------------------------------------------------------===// @@ -169,6 +246,12 @@ struct ArithIntNarrowingPass final void populateArithIntNarrowingPatterns( RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) { + // Add commute patterns with a higher benefit. This is to expose more + // optimization opportunities to narrowing patterns.
+ patterns.add<ExtensionOverExtract, ExtensionOverExtractElement, + ExtensionOverExtractStridedSlice>(patterns.getContext(), + PatternBenefit(2)); + patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options); } diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir index 21d5ab7..f1290e5 100644 --- a/mlir/test/Dialect/Arith/int-narrowing.mlir +++ b/mlir/test/Dialect/Arith/int-narrowing.mlir @@ -1,6 +1,10 @@ // RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" \ // RUN: --verify-diagnostics %s | FileCheck %s +//===----------------------------------------------------------------------===// +// arith.*itofp +//===----------------------------------------------------------------------===// + // CHECK-LABEL: func.func @sitofp_extsi_i16 // CHECK-SAME: (%[[ARG:.+]]: i16) // CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[ARG]] : i16 to f16 @@ -131,3 +135,103 @@ func.func @uitofp_extsi_i16(%a: i16) -> f16 { %f = arith.uitofp %b : i32 to f16 return %f : f16 } + +//===----------------------------------------------------------------------===// +// Commute Extension over Vector Ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func.func @extsi_over_extract_3xi16 +// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>) +// CHECK-NEXT: %[[EXTR:.+]] = vector.extract %[[ARG]][1] : vector<3xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[EXTR]] : i16 to f16 +// CHECK-NEXT: return %[[RET]] : f16 +func.func @extsi_over_extract_3xi16(%a: vector<3xi16>) -> f16 { + %b = arith.extsi %a : vector<3xi16> to vector<3xi32> + %c = vector.extract %b[1] : vector<3xi32> + %f = arith.sitofp %c : i32 to f16 + return %f : f16 +} + +// CHECK-LABEL: func.func @extui_over_extract_3xi16 +// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>) +// CHECK-NEXT: %[[EXTR:.+]] = vector.extract %[[ARG]][1] : vector<3xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.uitofp %[[EXTR]] : i16 to f16 +// CHECK-NEXT: return %[[RET]] : f16 +func.func @extui_over_extract_3xi16(%a: vector<3xi16>) -> f16 { + %b = arith.extui %a :
vector<3xi16> to vector<3xi32> + %c = vector.extract %b[1] : vector<3xi32> + %f = arith.uitofp %c : i32 to f16 + return %f : f16 +} + +// CHECK-LABEL: func.func @extsi_over_extractelement_3xi16 +// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>, %[[POS:.+]]: i32) +// CHECK-NEXT: %[[EXTR:.+]] = vector.extractelement %[[ARG]][%[[POS]] : i32] : vector<3xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[EXTR]] : i16 to f16 +// CHECK-NEXT: return %[[RET]] : f16 +func.func @extsi_over_extractelement_3xi16(%a: vector<3xi16>, %pos: i32) -> f16 { + %b = arith.extsi %a : vector<3xi16> to vector<3xi32> + %c = vector.extractelement %b[%pos : i32] : vector<3xi32> + %f = arith.sitofp %c : i32 to f16 + return %f : f16 +} + +// CHECK-LABEL: func.func @extui_over_extractelement_3xi16 +// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>, %[[POS:.+]]: i32) +// CHECK-NEXT: %[[EXTR:.+]] = vector.extractelement %[[ARG]][%[[POS]] : i32] : vector<3xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.uitofp %[[EXTR]] : i16 to f16 +// CHECK-NEXT: return %[[RET]] : f16 +func.func @extui_over_extractelement_3xi16(%a: vector<3xi16>, %pos: i32) -> f16 { + %b = arith.extui %a : vector<3xi16> to vector<3xi32> + %c = vector.extractelement %b[%pos : i32] : vector<3xi32> + %f = arith.uitofp %c : i32 to f16 + return %f : f16 +} + +// CHECK-LABEL: func.func @extsi_over_extract_strided_slice_1d +// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>) +// CHECK-NEXT: %[[EXTR:.+]] = vector.extract_strided_slice %[[ARG]] {offsets = [1], sizes = [2], strides = [1]} : vector<3xi16> to vector<2xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[EXTR]] : vector<2xi16> to vector<2xi32> +// CHECK-NEXT: return %[[RET]] : vector<2xi32> +func.func @extsi_over_extract_strided_slice_1d(%a: vector<3xi16>) -> vector<2xi32> { + %b = arith.extsi %a : vector<3xi16> to vector<3xi32> + %c = vector.extract_strided_slice %b + {offsets = [1], sizes = [2], strides = [1]} : vector<3xi32> to vector<2xi32> + return %c : vector<2xi32> +} + +// CHECK-LABEL: func.func 
@extui_over_extract_strided_slice_1d +// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>) +// CHECK-NEXT: %[[EXTR:.+]] = vector.extract_strided_slice %[[ARG]] {offsets = [1], sizes = [2], strides = [1]} : vector<3xi16> to vector<2xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[EXTR]] : vector<2xi16> to vector<2xi32> +// CHECK-NEXT: return %[[RET]] : vector<2xi32> +func.func @extui_over_extract_strided_slice_1d(%a: vector<3xi16>) -> vector<2xi32> { + %b = arith.extui %a : vector<3xi16> to vector<3xi32> + %c = vector.extract_strided_slice %b + {offsets = [1], sizes = [2], strides = [1]} : vector<3xi32> to vector<2xi32> + return %c : vector<2xi32> +} + +// CHECK-LABEL: func.func @extsi_over_extract_strided_slice_2d +// CHECK-SAME: (%[[ARG:.+]]: vector<2x3xi16>) +// CHECK-NEXT: %[[EXTR:.+]] = vector.extract_strided_slice %[[ARG]] {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi16> to vector<1x2xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[EXTR]] : vector<1x2xi16> to vector<1x2xi32> +// CHECK-NEXT: return %[[RET]] : vector<1x2xi32> +func.func @extsi_over_extract_strided_slice_2d(%a: vector<2x3xi16>) -> vector<1x2xi32> { + %b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32> + %c = vector.extract_strided_slice %b + {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32> + return %c : vector<1x2xi32> +} + +// CHECK-LABEL: func.func @extui_over_extract_strided_slice_2d +// CHECK-SAME: (%[[ARG:.+]]: vector<2x3xi16>) +// CHECK-NEXT: %[[EXTR:.+]] = vector.extract_strided_slice %[[ARG]] {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi16> to vector<1x2xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[EXTR]] : vector<1x2xi16> to vector<1x2xi32> +// CHECK-NEXT: return %[[RET]] : vector<1x2xi32> +func.func @extui_over_extract_strided_slice_2d(%a: vector<2x3xi16>) -> vector<1x2xi32> { + %b = arith.extui %a : vector<2x3xi16> to vector<2x3xi32> + %c = vector.extract_strided_slice %b + {offsets = [1, 1], sizes =
[1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32> + return %c : vector<1x2xi32> +} -- 2.7.4