From 8cdb4f9f66439729ff7a6781fd7d765f9e2b44c4 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Mon, 9 Jan 2023 17:05:54 +0000 Subject: [PATCH] [mlir][Index] Add index.mins and index.minu Signed and unsigned minimum operations were missing from the Index dialect and are needed to test integer range inference. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D141299 --- mlir/include/mlir/Dialect/Index/IR/IndexOps.td | 39 +++++++++++++++++++++ mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp | 6 ++++ mlir/lib/Dialect/Index/IR/IndexOps.cpp | 20 +++++++++++ .../test/Conversion/IndexToLLVM/index-to-llvm.mlir | 18 ++++++---- mlir/test/Dialect/Index/index-canonicalize.mlir | 40 ++++++++++++++++++++++ mlir/test/Dialect/Index/index-ops.mlir | 16 +++++---- 6 files changed, 126 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td index 9bed038..76008a1 100644 --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td @@ -282,6 +282,45 @@ def Index_MaxUOp : IndexBinaryOp<"maxu"> { } //===----------------------------------------------------------------------===// +// MinSOp +//===----------------------------------------------------------------------===// + +def Index_MinSOp : IndexBinaryOp<"mins"> { + let summary = "index signed minimum"; + let description = [{ + The `index.mins` operation takes two index values and computes their signed + minimum value. Treats the leading bit as the sign, i.e. `min(-2, 6) = -2`. + + Example: + + ```mlir + // c = min(a, b) + %c = index.mins %a, %b + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// MinUOp +//===----------------------------------------------------------------------===// + +def Index_MinUOp : IndexBinaryOp<"minu"> { + let summary = "index unsigned minimum"; + let description = [{ + The `index.minu` operation takes two index values and computes their + unsigned minimum value. Treats the leading bit as the most significant, i.e. + `min(15, 6) = 6` or `min(-2, 6) = 6`. + + Example: + + ```mlir + // c = min(a, b) + %c = index.minu %a, %b + ``` + }]; +} + +//===----------------------------------------------------------------------===// // ShlOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp index 9fa2e53..2b17342 100644 --- a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp +++ b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp @@ -268,6 +268,10 @@ using ConvertIndexMaxS = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexMaxU = mlir::OneToOneConvertToLLVMPattern; +using ConvertIndexMinS = + mlir::OneToOneConvertToLLVMPattern; +using ConvertIndexMinU = + mlir::OneToOneConvertToLLVMPattern; using ConvertIndexShl = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexShrS = mlir::OneToOneConvertToLLVMPattern; @@ -298,6 +302,8 @@ void index::populateIndexToLLVMConversionPatterns( ConvertIndexRemU, ConvertIndexMaxS, ConvertIndexMaxU, + ConvertIndexMinS, + ConvertIndexMinU, ConvertIndexShl, ConvertIndexShrS, ConvertIndexShrU, diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp index 2eadabb..dee6025 100644 --- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp +++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp @@ -287,6 +287,26 @@ OpFoldResult MaxUOp::fold(ArrayRef operands) { } //===----------------------------------------------------------------------===// +// MinSOp +//===----------------------------------------------------------------------===// + +OpFoldResult MinSOp::fold(ArrayRef operands) { + return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) { + return lhs.slt(rhs) ? lhs : rhs; + }); +} + +//===----------------------------------------------------------------------===// +// MinUOp +//===----------------------------------------------------------------------===// + +OpFoldResult MinUOp::fold(ArrayRef operands) { + return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) { + return lhs.ult(rhs) ? lhs : rhs; + }); +} + +//===----------------------------------------------------------------------===// // ShlOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir index 44aea80..8e4e37a 100644 --- a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir +++ b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir @@ -22,20 +22,24 @@ func.func @trivial_ops(%a: index, %b: index) { %7 = index.maxs %a, %b // CHECK: llvm.intr.umax %8 = index.maxu %a, %b + // CHECK: llvm.intr.smin + %9 = index.mins %a, %b + // CHECK: llvm.intr.umin + %10 = index.minu %a, %b // CHECK: llvm.shl - %9 = index.shl %a, %b + %11 = index.shl %a, %b // CHECK: llvm.ashr - %10 = index.shrs %a, %b + %12 = index.shrs %a, %b // CHECK: llvm.lshr - %11 = index.shru %a, %b + %13 = index.shru %a, %b // CHECK: llvm.add - %12 = index.add %a, %b + %14 = index.add %a, %b // CHECK: llvm.or - %13 = index.or %a, %b + %15 = index.or %a, %b // CHECK: llvm.xor - %14 = index.xor %a, %b + %16 = index.xor %a, %b // CHECK: llvm.mlir.constant(true - %15 = index.bool.constant true + %17 = index.bool.constant true return } diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir index d525ecd..c9b3079 100644 --- a/mlir/test/Dialect/Index/index-canonicalize.mlir +++ b/mlir/test/Dialect/Index/index-canonicalize.mlir @@ -279,6 +279,46 @@ func.func @maxu() -> index { return %0 : index } +// CHECK-LABEL: @mins +func.func @mins() -> index { + %lhs = index.constant -4 + %rhs = index.constant 2 + // CHECK: %[[A:.*]] = index.constant -4 + %0 = index.mins %lhs, %rhs + // CHECK: return %[[A]] + return %0 : index +} + +// CHECK-LABEL: @mins_nofold +func.func @mins_nofold() -> index { + %lhs = index.constant 1 + %rhs = index.constant 0x100000000 + // 32-bit result differs from 64-bit. + // CHECK: index.mins + %0 = index.mins %lhs, %rhs + return %0 : index +} + +// CHECK-LABEL: @mins_nofold_2 +func.func @mins_nofold_2() -> index { + %lhs = index.constant 0x7fffffff + %rhs = index.constant 0x80000000 + // 32-bit result differs from 64-bit. + // CHECK: index.mins + %0 = index.mins %lhs, %rhs + return %0 : index +} + +// CHECK-LABEL: @minu +func.func @minu() -> index { + %lhs = index.constant -1 + %rhs = index.constant 1 + // CHECK: %[[A:.*]] = index.constant 1 + %0 = index.minu %lhs, %rhs + // CHECK: return %[[A]] + return %0 : index +} + // CHECK-LABEL: @shl func.func @shl() -> index { %lhs = index.constant 128 diff --git a/mlir/test/Dialect/Index/index-ops.mlir b/mlir/test/Dialect/Index/index-ops.mlir index e79686b..ee55778 100644 --- a/mlir/test/Dialect/Index/index-ops.mlir +++ b/mlir/test/Dialect/Index/index-ops.mlir @@ -27,18 +27,22 @@ func.func @binary_ops(%a: index, %b: index) { %10 = index.maxs %a, %b // CHECK-NEXT: index.maxu %[[A]], %[[B]] %11 = index.maxu %a, %b + // CHECK-NEXT: index.mins %[[A]], %[[B]] + %12 = index.mins %a, %b + // CHECK-NEXT: index.minu %[[A]], %[[B]] + %13 = index.minu %a, %b // CHECK-NEXT: index.shl %[[A]], %[[B]] - %12 = index.shl %a, %b + %14 = index.shl %a, %b // CHECK-NEXT: index.shrs %[[A]], %[[B]] - %13 = index.shrs %a, %b + %15 = index.shrs %a, %b // CHECK-NEXT: index.shru %[[A]], %[[B]] - %14 = index.shru %a, %b + %16 = index.shru %a, %b // CHECK-NEXT: index.and %[[A]], %[[B]] - %15 = index.and %a, %b + %17 = index.and %a, %b // CHECK-NEXT: index.or %[[A]], %[[B]] - %16 = index.or %a, %b + %18 = index.or %a, %b // CHECK-NEXT: index.xor %[[A]], %[[B]] - %17 = index.xor %a, %b + %19 = index.xor %a, %b return } -- 2.7.4