From: Jakub Kuderski
Date: Mon, 1 May 2023 18:31:20 +0000 (-0400)
Subject: [mlir][arith] Add narrowing patterns to commute more vector ops
X-Git-Tag: upstream/17.0.6~9880
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=7f3b0e584513611bb1d804892eb269ae45d8e715;p=platform%2Fupstream%2Fllvm.git

[mlir][arith] Add narrowing patterns to commute more vector ops

This commutes the extension (`arith.extsi`, `arith.extui`) over the
following vector ops: `vector.broadcast`, `vector.shape_cast`,
`vector.transpose`, `vector.flat_transpose`.

I focused on these as I saw them getting created by vector unroll
patterns. Maybe except `vector.flat_transpose`.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149534
---

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 9716462..0c7afd9 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -249,6 +249,26 @@ using UIToFPPattern = IToFPPattern<arith::UIToFPOp>;
 // Patterns to Commute Extension Ops
 //===----------------------------------------------------------------------===//
 
+struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> {
+  using NarrowingPattern::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<ExtensionOp> ext =
+        ExtensionOp::from(op.getSource().getDefiningOp());
+    if (failed(ext))
+      return failure();
+
+    VectorType origTy = op.getResultVectorType();
+    VectorType newTy =
+        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
+    Value newBroadcast =
+        rewriter.create<vector::BroadcastOp>(op.getLoc(), newTy, ext->getIn());
+    ext->recreateAndReplace(rewriter, op, newBroadcast);
+    return success();
+  }
+};
+
 struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
   using NarrowingPattern::NarrowingPattern;
 
@@ -421,6 +441,68 @@ struct ExtensionOverInsertStridedSlice final
   }
 };
 
+struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> {
+  using NarrowingPattern::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<ExtensionOp> ext =
+        ExtensionOp::from(op.getSource().getDefiningOp());
+    if (failed(ext))
+      return failure();
+
+    VectorType origTy = op.getResultVectorType();
+    VectorType newTy =
+        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
+    Value newCast =
+        rewriter.create<vector::ShapeCastOp>(op.getLoc(), newTy, ext->getIn());
+    ext->recreateAndReplace(rewriter, op, newCast);
+    return success();
+  }
+};
+
+struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
+  using NarrowingPattern::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<ExtensionOp> ext =
+        ExtensionOp::from(op.getVector().getDefiningOp());
+    if (failed(ext))
+      return failure();
+
+    VectorType origTy = op.getResultVectorType();
+    VectorType newTy =
+        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
+    Value newTranspose = rewriter.create<vector::TransposeOp>(
+        op.getLoc(), newTy, ext->getIn(), op.getTransp());
+    ext->recreateAndReplace(rewriter, op, newTranspose);
+    return success();
+  }
+};
+
+struct ExtensionOverFlatTranspose final
+    : NarrowingPattern<vector::FlatTransposeOp> {
+  using NarrowingPattern::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(vector::FlatTransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<ExtensionOp> ext =
+        ExtensionOp::from(op.getMatrix().getDefiningOp());
+    if (failed(ext))
+      return failure();
+
+    VectorType origTy = op.getType();
+    VectorType newTy =
+        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
+    Value newTranspose = rewriter.create<vector::FlatTransposeOp>(
+        op.getLoc(), newTy, ext->getIn(), op.getRowsAttr(),
+        op.getColumnsAttr());
+    ext->recreateAndReplace(rewriter, op, newTranspose);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Pass Definitions
//===----------------------------------------------------------------------===// @@ -449,9 +531,11 @@ void populateArithIntNarrowingPatterns( RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) { // Add commute patterns with a higher benefit. This is to expose more // optimization opportunities to narrowing patterns. - patterns.add( + patterns.add( patterns.getContext(), options, PatternBenefit(2)); patterns.add(patterns.getContext(), options); diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir index 6d5299c..675a52b 100644 --- a/mlir/test/Dialect/Arith/int-narrowing.mlir +++ b/mlir/test/Dialect/Arith/int-narrowing.mlir @@ -442,3 +442,91 @@ func.func @extui_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector< %e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32> return %e : vector<2x3xi32> } + +// CHECK-LABEL: func.func @extsi_over_broadcast_3xi16 +// CHECK-SAME: (%[[ARG:.+]]: i16) +// CHECK-NEXT: %[[BCST:.+]] = vector.broadcast %[[ARG]] : i16 to vector<3xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[BCST]] : vector<3xi16> to vector<3xi32> +// CHECK-NEXT: return %[[RET]] : vector<3xi32> +func.func @extsi_over_broadcast_3xi16(%a: i16) -> vector<3xi32> { + %b = arith.extsi %a : i16 to i32 + %r = vector.broadcast %b : i32 to vector<3xi32> + return %r : vector<3xi32> +} + +// CHECK-LABEL: func.func @extui_over_broadcast_2x3xi16 +// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>) +// CHECK-NEXT: %[[BCST:.+]] = vector.broadcast %[[ARG]] : vector<3xi16> to vector<2x3xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[BCST]] : vector<2x3xi16> to vector<2x3xi32> +// CHECK-NEXT: return %[[RET]] : vector<2x3xi32> +func.func @extui_over_broadcast_2x3xi16(%a: vector<3xi16>) -> vector<2x3xi32> { + %b = arith.extui %a : vector<3xi16> to vector<3xi32> + %r = vector.broadcast %b : vector<3xi32> to vector<2x3xi32> + return %r : vector<2x3xi32> +} 
+
+// CHECK-LABEL: func.func @extsi_over_shape_cast_2x3xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<2x3xi16>)
+// CHECK-NEXT:    %[[CAST:.+]] = vector.shape_cast %[[ARG]] : vector<2x3xi16> to vector<3x2xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[CAST]] : vector<3x2xi16> to vector<3x2xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<3x2xi32>
+func.func @extsi_over_shape_cast_2x3xi16(%a: vector<2x3xi16>) -> vector<3x2xi32> {
+  %b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32>
+  %r = vector.shape_cast %b : vector<2x3xi32> to vector<3x2xi32>
+  return %r : vector<3x2xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_shape_cast_5x2x3xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<5x2x3xi16>)
+// CHECK-NEXT:    %[[CAST:.+]] = vector.shape_cast %[[ARG]] : vector<5x2x3xi16> to vector<2x3x5xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[CAST]] : vector<2x3x5xi16> to vector<2x3x5xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<2x3x5xi32>
+func.func @extui_over_shape_cast_5x2x3xi16(%a: vector<5x2x3xi16>) -> vector<2x3x5xi32> {
+  %b = arith.extui %a : vector<5x2x3xi16> to vector<5x2x3xi32>
+  %r = vector.shape_cast %b : vector<5x2x3xi32> to vector<2x3x5xi32>
+  return %r : vector<2x3x5xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_transpose_2x3xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<2x3xi16>)
+// CHECK-NEXT:    %[[TRAN:.+]] = vector.transpose %[[ARG]], [1, 0] : vector<2x3xi16> to vector<3x2xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[TRAN]] : vector<3x2xi16> to vector<3x2xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<3x2xi32>
+func.func @extsi_over_transpose_2x3xi16(%a: vector<2x3xi16>) -> vector<3x2xi32> {
+  %b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32>
+  %r = vector.transpose %b, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+  return %r : vector<3x2xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_transpose_5x2x3xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<5x2x3xi16>)
+// CHECK-NEXT:    %[[TRAN:.+]] = vector.transpose %[[ARG]], [1, 2, 0] : vector<5x2x3xi16> to vector<2x3x5xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[TRAN]] : vector<2x3x5xi16> to vector<2x3x5xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<2x3x5xi32>
+func.func @extui_over_transpose_5x2x3xi16(%a: vector<5x2x3xi16>) -> vector<2x3x5xi32> {
+  %b = arith.extui %a : vector<5x2x3xi16> to vector<5x2x3xi32>
+  %r = vector.transpose %b, [1, 2, 0] : vector<5x2x3xi32> to vector<2x3x5xi32>
+  return %r : vector<2x3x5xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_flat_transpose_16xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<16xi16>)
+// CHECK-NEXT:    %[[TRAN:.+]] = vector.flat_transpose %[[ARG]] {columns = 4 : i32, rows = 4 : i32} : vector<16xi16> -> vector<16xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[TRAN]] : vector<16xi16> to vector<16xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<16xi32>
+func.func @extsi_over_flat_transpose_16xi16(%a: vector<16xi16>) -> vector<16xi32> {
+  %b = arith.extsi %a : vector<16xi16> to vector<16xi32>
+  %r = vector.flat_transpose %b {columns = 4 : i32, rows = 4 : i32} : vector<16xi32> -> vector<16xi32>
+  return %r : vector<16xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_flat_transpose_16xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<16xi16>)
+// CHECK-NEXT:    %[[TRAN:.+]] = vector.flat_transpose %[[ARG]] {columns = 8 : i32, rows = 2 : i32} : vector<16xi16> -> vector<16xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[TRAN]] : vector<16xi16> to vector<16xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<16xi32>
+func.func @extui_over_flat_transpose_16xi16(%a: vector<16xi16>) -> vector<16xi32> {
+  %b = arith.extui %a : vector<16xi16> to vector<16xi32>
+  %r = vector.flat_transpose %b {columns = 8 : i32, rows = 2 : i32} : vector<16xi32> -> vector<16xi32>
+  return %r : vector<16xi32>
+}