From 9701c5abd669125da4bc6dc538eacb6f7a39dbd1 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Tue, 2 May 2023 10:48:01 -0400 Subject: [PATCH] [mlir][arith] Add narrowing patterns for `max*i` and `min*i` Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D149583 --- mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp | 26 ++++- mlir/test/Dialect/Arith/int-narrowing.mlir | 128 +++++++++++++++++++++ 2 files changed, 153 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp index cb6e437..344caff 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp @@ -368,6 +368,29 @@ struct DivUIPattern final : BinaryOpNarrowingPattern { }; //===----------------------------------------------------------------------===// +// Min/Max Patterns +//===----------------------------------------------------------------------===// + +template +struct MinMaxPattern final : BinaryOpNarrowingPattern { + using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; + + bool isSupported(ExtensionOp ext) const override { + return ext.getKind() == Kind; + } + + // Min/max returns one of the arguments and does not require any extra result + // bits. + unsigned getResultBitsProduced(unsigned operandBits) const override { + return operandBits; + } +}; +using MaxSIPattern = MinMaxPattern; +using MaxUIPattern = MinMaxPattern; +using MinSIPattern = MinMaxPattern; +using MinUIPattern = MinMaxPattern; + +//===----------------------------------------------------------------------===// // *IToFPOp Patterns //===----------------------------------------------------------------------===// @@ -690,7 +713,8 @@ void populateArithIntNarrowingPatterns( patterns.getContext(), options, PatternBenefit(2)); patterns.add( + DivUIPattern, MaxSIPattern, MaxUIPattern, MinSIPattern, + MinUIPattern, SIToFPPattern, UIToFPPattern>( patterns.getContext(), options); } diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir index 4b155ad..484f601 100644 --- a/mlir/test/Dialect/Arith/int-narrowing.mlir +++ b/mlir/test/Dialect/Arith/int-narrowing.mlir @@ -474,6 +474,134 @@ func.func @uitofp_extsi_i16(%a: i16) -> f16 { } //===----------------------------------------------------------------------===// +// arith.maxsi +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func.func @maxsi_extsi_i8 +// CHECK-SAME: (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8) +// CHECK-NEXT: %[[MAX:.+]] = arith.maxsi %[[LHS]], %[[RHS]] : i8 +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[MAX]] : i8 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @maxsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extsi %rhs : i8 to i32 + %r = arith.maxsi %a, %b : i32 + return %r : i32 +} + +// This patterns should only apply to `arith.maxsi` ops with sign-extended +// arguments. +// +// CHECK-LABEL: func.func @maxsi_extui_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[MAX:.+]] = arith.maxsi %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[MAX]] : i32 +func.func @maxsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extui %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.maxsi %a, %b : i32 + return %r : i32 +} + +//===----------------------------------------------------------------------===// +// arith.maxui +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func.func @maxui_extui_i8 +// CHECK-SAME: (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8) +// CHECK-NEXT: %[[MAX:.+]] = arith.maxui %[[LHS]], %[[RHS]] : i8 +// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[MAX]] : i8 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @maxui_extui_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extui %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.maxui %a, %b : i32 + return %r : i32 +} + +// This patterns should only apply to `arith.maxsi` ops with zero-extended +// arguments. +// +// CHECK-LABEL: func.func @maxui_extsi_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[MAX:.+]] = arith.maxui %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[MAX]] : i32 +func.func @maxui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extsi %rhs : i8 to i32 + %r = arith.maxui %a, %b : i32 + return %r : i32 +} + +//===----------------------------------------------------------------------===// +// arith.minsi +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func.func @minsi_extsi_i8 +// CHECK-SAME: (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8) +// CHECK-NEXT: %[[min:.+]] = arith.minsi %[[LHS]], %[[RHS]] : i8 +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[min]] : i8 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @minsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extsi %rhs : i8 to i32 + %r = arith.minsi %a, %b : i32 + return %r : i32 +} + +// This patterns should only apply to `arith.minsi` ops with sign-extended +// arguments. +// +// CHECK-LABEL: func.func @minsi_extui_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[min:.+]] = arith.minsi %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[min]] : i32 +func.func @minsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extui %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.minsi %a, %b : i32 + return %r : i32 +} + +//===----------------------------------------------------------------------===// +// arith.minui +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func.func @minui_extui_i8 +// CHECK-SAME: (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8) +// CHECK-NEXT: %[[min:.+]] = arith.minui %[[LHS]], %[[RHS]] : i8 +// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[min]] : i8 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @minui_extui_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extui %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.minui %a, %b : i32 + return %r : i32 +} + +// This patterns should only apply to `arith.minsi` ops with zero-extended +// arguments. +// +// CHECK-LABEL: func.func @minui_extsi_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[min:.+]] = arith.minui %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[min]] : i32 +func.func @minui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extsi %rhs : i8 to i32 + %r = arith.minui %a, %b : i32 + return %r : i32 +} + +//===----------------------------------------------------------------------===// // Commute Extension over Vector Ops //===----------------------------------------------------------------------===// -- 2.7.4