From: Krzysztof Drewniak Date: Mon, 9 Jan 2023 17:09:17 +0000 (+0000) Subject: [mlir][Index] Implement InferIntRangeInterface, re-land X-Git-Tag: upstream/17.0.6~20209 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5af9d16dae71f2c2087ba88c5fc06893e6aecfe9;p=platform%2Fupstream%2Fllvm.git [mlir][Index] Implement InferIntRangeInterface, re-land Re-land D140899 to fix a missing dependency in the index dialect's CMakeLists.txt. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D142147 --- diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h index 85a0549..d8debfb 100644 --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h @@ -13,6 +13,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td index 76008a1..8fbccc4 100644 --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td @@ -12,6 +12,7 @@ include "mlir/Dialect/Index/IR/IndexDialect.td" include "mlir/Dialect/Index/IR/IndexEnums.td" include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpAsmInterface.td" @@ -23,7 +24,8 @@ include "mlir/IR/OpBase.td" /// Base class for Index dialect operations. class IndexOp traits = []> - : Op; + : Op] # traits>; //===----------------------------------------------------------------------===// // IndexBinaryOp diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h new file mode 100644 index 0000000..7ee059c --- /dev/null +++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h @@ -0,0 +1,126 @@ +//===- InferIntRangeCommon.cpp - Inference for common ops --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares implementations of range inference for operations that are +// common to both the `arith` and `index` dialects to facilitate reuse. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H +#define MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H + +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { +namespace intrange { +/// Function that performs inference on an array of `ConstantIntRanges`, +/// abstracted away here to permit writing the function that handles both +/// 64- and 32-bit index types. +using InferRangeFn = + function_ref)>; + +static constexpr unsigned indexMinWidth = 32; +static constexpr unsigned indexMaxWidth = 64; + +enum class CmpMode : uint32_t { Both, Signed, Unsigned }; + +/// Compute `inferFn` on `ranges`, whose size should be the index storage +/// bitwidth. Then, compute the function on `argRanges` again after truncating +/// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is +/// equal to the 32-bit result, use it (to preserve compatibility with folders +/// and inference precision), and take the union of the results otherwise. +/// +/// The `mode` argument specifies if the unsigned, signed, or both results of +/// the inference computation should be used when comparing the results. +ConstantIntRanges inferIndexOp(InferRangeFn inferFn, + ArrayRef argRanges, + CmpMode mode); + +/// Independently zero-extend the unsigned values and sign-extend the signed +/// values in `range` to `destWidth` bits, returning the resulting range. +ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth); + +/// Use the unsigned values in `range` to zero-extend it to `destWidth`. +ConstantIntRanges extUIRange(const ConstantIntRanges &range, + unsigned destWidth); + +/// Use the signed values in `range` to sign-extend it to `destWidth`. +ConstantIntRanges extSIRange(const ConstantIntRanges &range, + unsigned destWidth); + +/// Truncate `range` to `destWidth` bits, taking care to handle cases such as +/// the truncation of [255, 256] to i8 not being a uniform range. +ConstantIntRanges truncRange(const ConstantIntRanges &range, + unsigned destWidth); + +ConstantIntRanges inferAdd(ArrayRef argRanges); + +ConstantIntRanges inferSub(ArrayRef argRanges); + +ConstantIntRanges inferMul(ArrayRef argRanges); + +ConstantIntRanges inferDivS(ArrayRef argRanges); + +ConstantIntRanges inferDivU(ArrayRef argRanges); + +ConstantIntRanges inferCeilDivS(ArrayRef argRanges); + +ConstantIntRanges inferCeilDivU(ArrayRef argRanges); + +ConstantIntRanges inferFloorDivS(ArrayRef argRanges); + +ConstantIntRanges inferRemS(ArrayRef argRanges); + +ConstantIntRanges inferRemU(ArrayRef argRanges); + +ConstantIntRanges inferMaxS(ArrayRef argRanges); + +ConstantIntRanges inferMaxU(ArrayRef argRanges); + +ConstantIntRanges inferMinS(ArrayRef argRanges); + +ConstantIntRanges inferMinU(ArrayRef argRanges); + +ConstantIntRanges inferAnd(ArrayRef argRanges); + +ConstantIntRanges inferOr(ArrayRef argRanges); + +ConstantIntRanges inferXor(ArrayRef argRanges); + +ConstantIntRanges inferShl(ArrayRef argRanges); + +ConstantIntRanges inferShrS(ArrayRef argRanges); + +ConstantIntRanges inferShrU(ArrayRef argRanges); + +/// Copy of the enum from `arith` and `index` to allow the common integer range +/// infrastructure to not depend on either dialect. +enum class CmpPredicate : uint64_t { + eq, + ne, + slt, + sle, + sgt, + sge, + ult, + ule, + ugt, + uge, +}; + +/// Returns a boolean value if `pred` is statically true or false for +/// anypossible inputs falling within `lhs` and `rhs`, and std::nullopt if the +/// value of the predicate cannot be determined. +Optional evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs); + +} // namespace intrange +} // namespace mlir + +#endif // MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt index 0de17bb..ffbe801 100644 --- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt @@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRArithDialect LINK_LIBS PUBLIC MLIRDialect + MLIRInferIntRangeCommon MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRIR diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp index 10d6ef2..971477f 100644 --- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp +++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" #include "llvm/Support/Debug.h" #include @@ -16,48 +17,7 @@ using namespace mlir; using namespace mlir::arith; - -/// Function that evaluates the result of doing something on arithmetic -/// constants and returns std::nullopt on overflow. -using ConstArithFn = - function_ref(const APInt &, const APInt &)>; - -/// Return the maxmially wide signed or unsigned range for a given bitwidth. - -/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, -/// If either computation overflows, make the result unbounded. -static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, - const APInt &minRight, - const APInt &maxLeft, - const APInt &maxRight, bool isSigned) { - std::optional maybeMin = op(minLeft, minRight); - std::optional maybeMax = op(maxLeft, maxRight); - if (maybeMin && maybeMax) - return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned); - return ConstantIntRanges::maxRange(minLeft.getBitWidth()); -} - -/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, -/// ignoring unbounded values. Returns the maximal range if `op` overflows. -static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef lhs, - ArrayRef rhs, bool isSigned) { - unsigned width = lhs[0].getBitWidth(); - APInt min = - isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width); - APInt max = - isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width); - for (const APInt &left : lhs) { - for (const APInt &right : rhs) { - std::optional maybeThisResult = op(left, right); - if (!maybeThisResult) - return ConstantIntRanges::maxRange(width); - APInt result = std::move(*maybeThisResult); - min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min; - max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max; - } - } - return ConstantIntRanges::range(min, max, isSigned); -} +using namespace mlir::intrange; //===----------------------------------------------------------------------===// // ConstantOp @@ -78,25 +38,7 @@ void arith::ConstantOp::inferResultRanges(ArrayRef argRanges, void arith::AddIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - ConstArithFn uadd = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.uadd_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - ConstArithFn sadd = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.sadd_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - - ConstantIntRanges urange = computeBoundsBy( - uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); - ConstantIntRanges srange = computeBoundsBy( - sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); - setResultRange(getResult(), urange.intersection(srange)); + setResultRange(getResult(), inferAdd(argRanges)); } //===----------------------------------------------------------------------===// @@ -105,25 +47,7 @@ void arith::AddIOp::inferResultRanges(ArrayRef argRanges, void arith::SubIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn usub = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.usub_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - ConstArithFn ssub = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.ssub_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - ConstantIntRanges urange = computeBoundsBy( - usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); - ConstantIntRanges srange = computeBoundsBy( - ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); - setResultRange(getResult(), urange.intersection(srange)); + setResultRange(getResult(), inferSub(argRanges)); } //===----------------------------------------------------------------------===// @@ -132,96 +56,25 @@ void arith::SubIOp::inferResultRanges(ArrayRef argRanges, void arith::MulIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn umul = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.umul_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - ConstArithFn smul = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.smul_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - - ConstantIntRanges urange = - minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/false); - ConstantIntRanges srange = - minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, - /*isSigned=*/true); - - setResultRange(getResult(), urange.intersection(srange)); + setResultRange(getResult(), inferMul(argRanges)); } //===----------------------------------------------------------------------===// // DivUIOp //===----------------------------------------------------------------------===// -/// Fix up division results (ex. for ceiling and floor), returning an APInt -/// if there has been no overflow -using DivisionFixupFn = function_ref( - const APInt &lhs, const APInt &rhs, const APInt &result)>; - -static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs, - DivisionFixupFn fixup) { - const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), - &rhsMax = rhs.umax(); - - if (!rhsMin.isZero()) { - auto udiv = [&fixup](const APInt &a, - const APInt &b) -> std::optional { - return fixup(a, b, a.udiv(b)); - }; - return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, - /*isSigned=*/false); - } - // Otherwise, it's possible we might divide by 0. - return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); -} - void arith::DivUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferDivUIRange(argRanges[0], argRanges[1], - [](const APInt &lhs, const APInt &rhs, - const APInt &result) { return result; })); + setResultRange(getResult(), inferDivU(argRanges)); } //===----------------------------------------------------------------------===// // DivSIOp //===----------------------------------------------------------------------===// -static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs, - DivisionFixupFn fixup) { - const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), - &rhsMax = rhs.smax(); - bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); - - if (canDivide) { - auto sdiv = [&fixup](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.sdiv_ov(b, overflowed); - return overflowed ? std::optional() : fixup(a, b, result); - }; - return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, - /*isSigned=*/true); - } - return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); -} - void arith::DivSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferDivSIRange(argRanges[0], argRanges[1], - [](const APInt &lhs, const APInt &rhs, - const APInt &result) { return result; })); + setResultRange(getResult(), inferDivS(argRanges)); } //===----------------------------------------------------------------------===// @@ -230,20 +83,7 @@ void arith::DivSIOp::inferResultRanges(ArrayRef argRanges, void arith::CeilDivUIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - DivisionFixupFn ceilDivUIFix = - [](const APInt &lhs, const APInt &rhs, - const APInt &result) -> std::optional { - if (!lhs.urem(rhs).isZero()) { - bool overflowed = false; - APInt corrected = - result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); - return overflowed ? std::optional() : corrected; - } - return result; - }; - setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix)); + setResultRange(getResult(), inferCeilDivU(argRanges)); } //===----------------------------------------------------------------------===// @@ -252,20 +92,7 @@ void arith::CeilDivUIOp::inferResultRanges( void arith::CeilDivSIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - DivisionFixupFn ceilDivSIFix = - [](const APInt &lhs, const APInt &rhs, - const APInt &result) -> std::optional { - if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { - bool overflowed = false; - APInt corrected = - result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed); - return overflowed ? std::optional() : corrected; - } - return result; - }; - setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix)); + setResultRange(getResult(), inferCeilDivS(argRanges)); } //===----------------------------------------------------------------------===// @@ -274,20 +101,7 @@ void arith::CeilDivSIOp::inferResultRanges( void arith::FloorDivSIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - DivisionFixupFn floorDivSIFix = - [](const APInt &lhs, const APInt &rhs, - const APInt &result) -> std::optional { - if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { - bool overflowed = false; - APInt corrected = - result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed); - return overflowed ? std::optional() : corrected; - } - return result; - }; - setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix)); + return setResultRange(getResult(), inferFloorDivS(argRanges)); } //===----------------------------------------------------------------------===// @@ -296,29 +110,7 @@ void arith::FloorDivSIOp::inferResultRanges( void arith::RemUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); - - unsigned width = rhsMin.getBitWidth(); - APInt umin = APInt::getZero(width); - APInt umax = APInt::getMaxValue(width); - - if (!rhsMin.isZero()) { - umax = rhsMax - 1; - // Special case: sweeping out a contiguous range in N/[modulus] - if (rhsMin == rhsMax) { - const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); - if ((lhsMax - lhsMin).ult(rhsMax)) { - APInt minRem = lhsMin.urem(rhsMax); - APInt maxRem = lhsMax.urem(rhsMax); - if (minRem.ule(maxRem)) { - umin = minRem; - umax = maxRem; - } - } - } - } - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); + setResultRange(getResult(), inferRemU(argRanges)); } //===----------------------------------------------------------------------===// @@ -327,67 +119,16 @@ void arith::RemUIOp::inferResultRanges(ArrayRef argRanges, void arith::RemSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), - &rhsMax = rhs.smax(); - - unsigned width = rhsMax.getBitWidth(); - APInt smin = APInt::getSignedMinValue(width); - APInt smax = APInt::getSignedMaxValue(width); - // No bounds if zero could be a divisor. - bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative()); - if (canBound) { - APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs(); - bool canNegativeDividend = lhsMin.isNegative(); - bool canPositiveDividend = lhsMax.isStrictlyPositive(); - APInt zero = APInt::getZero(maxDivisor.getBitWidth()); - APInt maxPositiveResult = maxDivisor - 1; - APInt minNegativeResult = -maxPositiveResult; - smin = canNegativeDividend ? minNegativeResult : zero; - smax = canPositiveDividend ? maxPositiveResult : zero; - // Special case: sweeping out a contiguous range in N/[modulus]. - if (rhsMin == rhsMax) { - if ((lhsMax - lhsMin).ult(maxDivisor)) { - APInt minRem = lhsMin.srem(maxDivisor); - APInt maxRem = lhsMax.srem(maxDivisor); - if (minRem.sle(maxRem)) { - smin = minRem; - smax = maxRem; - } - } - } - } - setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); + setResultRange(getResult(), inferRemS(argRanges)); } //===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// -/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, -/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits -/// that both bonuds have in common. This gives us a consertive approximation -/// for what values can be passed to bitwise operations. -static std::tuple -widenBitwiseBounds(const ConstantIntRanges &bound) { - APInt leftVal = bound.umin(), rightVal = bound.umax(); - unsigned bitwidth = leftVal.getBitWidth(); - unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros(); - leftVal.clearLowBits(differingBits); - rightVal.setLowBits(differingBits); - return std::make_tuple(std::move(leftVal), std::move(rightVal)); -} - void arith::AndIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto andi = [](const APInt &a, const APInt &b) -> std::optional { - return a & b; - }; - setResultRange(getResult(), - minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false)); + setResultRange(getResult(), inferAnd(argRanges)); } //===----------------------------------------------------------------------===// @@ -396,14 +137,7 @@ void arith::AndIOp::inferResultRanges(ArrayRef argRanges, void arith::OrIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto ori = [](const APInt &a, const APInt &b) -> std::optional { - return a | b; - }; - setResultRange(getResult(), - minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false)); + setResultRange(getResult(), inferOr(argRanges)); } //===----------------------------------------------------------------------===// @@ -412,14 +146,7 @@ void arith::OrIOp::inferResultRanges(ArrayRef argRanges, void arith::XOrIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto xori = [](const APInt &a, const APInt &b) -> std::optional { - return a ^ b; - }; - setResultRange(getResult(), - minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false)); + setResultRange(getResult(), inferXor(argRanges)); } //===----------------------------------------------------------------------===// @@ -428,11 +155,7 @@ void arith::XOrIOp::inferResultRanges(ArrayRef argRanges, void arith::MaxSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); - const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax(); - setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); + setResultRange(getResult(), inferMaxS(argRanges)); } //===----------------------------------------------------------------------===// @@ -441,11 +164,7 @@ void arith::MaxSIOp::inferResultRanges(ArrayRef argRanges, void arith::MaxUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin(); - const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); + setResultRange(getResult(), inferMaxU(argRanges)); } //===----------------------------------------------------------------------===// @@ -454,11 +173,7 @@ void arith::MaxUIOp::inferResultRanges(ArrayRef argRanges, void arith::MinSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); - const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax(); - setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); + setResultRange(getResult(), inferMinS(argRanges)); } //===----------------------------------------------------------------------===// @@ -467,94 +182,40 @@ void arith::MinSIOp::inferResultRanges(ArrayRef argRanges, void arith::MinUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin(); - const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); + setResultRange(getResult(), inferMinU(argRanges)); } //===----------------------------------------------------------------------===// // ExtUIOp //===----------------------------------------------------------------------===// -static ConstantIntRanges extUIRange(const ConstantIntRanges &range, - Type destType) { - unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - APInt umin = range.umin().zext(destWidth); - APInt umax = range.umax().zext(destWidth); - return ConstantIntRanges::fromUnsigned(umin, umax); -} - void arith::ExtUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - Type destType = getResult().getType(); - setResultRange(getResult(), extUIRange(argRanges[0], destType)); + unsigned destWidth = + ConstantIntRanges::getStorageBitwidth(getResult().getType()); + setResultRange(getResult(), extUIRange(argRanges[0], destWidth)); } //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// -static ConstantIntRanges extSIRange(const ConstantIntRanges &range, - Type destType) { - unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - APInt smin = range.smin().sext(destWidth); - APInt smax = range.smax().sext(destWidth); - return ConstantIntRanges::fromSigned(smin, smax); -} - void arith::ExtSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - Type destType = getResult().getType(); - setResultRange(getResult(), extSIRange(argRanges[0], destType)); + unsigned destWidth = + ConstantIntRanges::getStorageBitwidth(getResult().getType()); + setResultRange(getResult(), extSIRange(argRanges[0], destWidth)); } //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// -static ConstantIntRanges truncIRange(const ConstantIntRanges &range, - Type destType) { - unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb], - // the range of the resulting value is not contiguous ind includes 0. - // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2], - // but you can't truncate [255, 257] similarly. - bool hasUnsignedRollover = - range.umin().lshr(destWidth) != range.umax().lshr(destWidth); - APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth) - : range.umin().trunc(destWidth); - APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth) - : range.umax().trunc(destWidth); - - // Signed post-truncation rollover will not occur when either: - // - The high parts of the min and max, plus the sign bit, are the same - // - The high halves + sign bit of the min and max are either all 1s or all 0s - // and you won't create a [positive, negative] range by truncating. - // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8 - // but not [255, 257]_i16 to a range of i8s. You can also truncate - // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16. - // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e) - // will truncate to 0x7e, which is greater than 0 - APInt sminHighPart = range.smin().ashr(destWidth - 1); - APInt smaxHighPart = range.smax().ashr(destWidth - 1); - bool hasSignedOverflow = - (sminHighPart != smaxHighPart) && - !(sminHighPart.isAllOnes() && - (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && - !(sminHighPart.isZero() && smaxHighPart.isZero()); - APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth) - : range.smin().trunc(destWidth); - APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth) - : range.smax().trunc(destWidth); - return {umin, umax, smin, smax}; -} - void arith::TruncIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - Type destType = getResult().getType(); - setResultRange(getResult(), truncIRange(argRanges[0], destType)); + unsigned destWidth = + ConstantIntRanges::getStorageBitwidth(getResult().getType()); + setResultRange(getResult(), truncRange(argRanges[0], destWidth)); } //===----------------------------------------------------------------------===// @@ -569,9 +230,9 @@ void arith::IndexCastOp::inferResultRanges( unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); if (srcWidth < destWidth) - setResultRange(getResult(), extSIRange(argRanges[0], destType)); + setResultRange(getResult(), extSIRange(argRanges[0], destWidth)); else if (srcWidth > destWidth) - setResultRange(getResult(), truncIRange(argRanges[0], destType)); + setResultRange(getResult(), truncRange(argRanges[0], destWidth)); else setResultRange(getResult(), argRanges[0]); } @@ -588,9 +249,9 @@ void arith::IndexCastUIOp::inferResultRanges( unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); if (srcWidth < destWidth) - setResultRange(getResult(), extUIRange(argRanges[0], destType)); + setResultRange(getResult(), extUIRange(argRanges[0], destWidth)); else if (srcWidth > destWidth) - setResultRange(getResult(), truncIRange(argRanges[0], destType)); + setResultRange(getResult(), truncRange(argRanges[0], destWidth)); else setResultRange(getResult(), argRanges[0]); } @@ -599,51 +260,19 @@ void arith::IndexCastUIOp::inferResultRanges( // CmpIOp //===----------------------------------------------------------------------===// -bool isStaticallyTrue(arith::CmpIPredicate pred, const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs) { - switch (pred) { - case arith::CmpIPredicate::sle: - case arith::CmpIPredicate::slt: - return (applyCmpPredicate(pred, lhs.smax(), rhs.smin())); - case arith::CmpIPredicate::ule: - case arith::CmpIPredicate::ult: - return applyCmpPredicate(pred, lhs.umax(), rhs.umin()); - case arith::CmpIPredicate::sge: - case arith::CmpIPredicate::sgt: - return applyCmpPredicate(pred, lhs.smin(), rhs.smax()); - case arith::CmpIPredicate::uge: - case arith::CmpIPredicate::ugt: - return applyCmpPredicate(pred, lhs.umin(), rhs.umax()); - case arith::CmpIPredicate::eq: { - std::optional lhsConst = lhs.getConstantValue(); - std::optional rhsConst = rhs.getConstantValue(); - return lhsConst && rhsConst && lhsConst == rhsConst; - } - case arith::CmpIPredicate::ne: { - // While equality requires that there is an interpration of the preceeding - // computations that produces equal constants, whether that be signed or - // unsigned, statically determining inequality requires that neither - // interpretation produce potentially overlapping ranges. - bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) || - isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs); - bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) || - isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs); - return sne && une; - } - } - return false; -} - void arith::CmpIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - arith::CmpIPredicate pred = getPredicate(); + arith::CmpIPredicate arithPred = getPredicate(); + intrange::CmpPredicate pred = static_cast(arithPred); const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; APInt min = APInt::getZero(1); APInt max = APInt::getAllOnesValue(1); - if (isStaticallyTrue(pred, lhs, rhs)) + + Optional truthValue = intrange::evaluatePred(pred, lhs, rhs); + if (truthValue.has_value() && *truthValue) min = max; - else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) + else if (truthValue.has_value() && !(*truthValue)) max = min; setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); @@ -673,18 +302,7 @@ void arith::SelectOp::inferResultRanges(ArrayRef argRanges, void arith::ShLIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - ConstArithFn shl = [](const APInt &l, - const APInt &r) -> std::optional { - return r.uge(r.getBitWidth()) ? std::optional() : l.shl(r); - }; - ConstantIntRanges urange = - minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/false); - ConstantIntRanges srange = - minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/true); - setResultRange(getResult(), urange.intersection(srange)); + setResultRange(getResult(), inferShl(argRanges)); } //===----------------------------------------------------------------------===// @@ -693,15 +311,7 @@ void arith::ShLIOp::inferResultRanges(ArrayRef argRanges, void arith::ShRUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn lshr = [](const APInt &l, - const APInt &r) -> std::optional { - return r.uge(r.getBitWidth()) ? std::optional() : l.lshr(r); - }; - setResultRange(getResult(), minMaxBy(lshr, {lhs.umin(), lhs.umax()}, - {rhs.umin(), rhs.umax()}, - /*isSigned=*/false)); + setResultRange(getResult(), inferShrU(argRanges)); } //===----------------------------------------------------------------------===// @@ -710,14 +320,5 @@ void arith::ShRUIOp::inferResultRanges(ArrayRef argRanges, void arith::ShRSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn ashr = [](const APInt &l, - const APInt &r) -> std::optional { - return r.uge(r.getBitWidth()) ? std::optional() : l.ashr(r); - }; - - setResultRange(getResult(), - minMaxBy(ashr, {lhs.smin(), lhs.smax()}, - {rhs.umin(), rhs.umax()}, /*isSigned=*/true)); + setResultRange(getResult(), inferShrS(argRanges)); } diff --git a/mlir/lib/Dialect/Index/IR/CMakeLists.txt b/mlir/lib/Dialect/Index/IR/CMakeLists.txt index 53321f1..fce47d2 100644 --- a/mlir/lib/Dialect/Index/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Index/IR/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRIndexDialect IndexAttrs.cpp IndexDialect.cpp IndexOps.cpp + InferIntRangeInterfaceImpls.cpp DEPENDS MLIRIndexOpsIncGen @@ -10,6 +11,8 @@ add_mlir_dialect_library(MLIRIndexDialect MLIRDialect MLIRIR MLIRCastInterfaces + MLIRInferIntRangeCommon + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp new file mode 100644 index 0000000..6daa764 --- /dev/null +++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp @@ -0,0 +1,252 @@ +//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "int-range-analysis" + +using namespace mlir; +using namespace mlir::index; +using namespace mlir::intrange; + +//===----------------------------------------------------------------------===// +// Constants +//===----------------------------------------------------------------------===// + +void ConstantOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + const APInt &value = getValue(); + setResultRange(getResult(), ConstantIntRanges::constant(value)); +} + +void BoolConstantOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + bool value = getValue(); + APInt asInt(/*numBits=*/1, value); + setResultRange(getResult(), ConstantIntRanges::constant(asInt)); +} + +//===----------------------------------------------------------------------===// +// Arithmec operations. All of these operations will have their results inferred +// using both the 64-bit values and truncated 32-bit values of their inputs, +// with the results being the union of those inferences, except where the +// truncation of the 64-bit result is equal to the 32-bit result (at which time +// we take the 64-bit result). +//===----------------------------------------------------------------------===// + +void AddOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both)); +} + +void SubOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both)); +} + +void MulOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both)); +} + +void DivUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned)); +} + +void DivSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferDivS, argRanges, CmpMode::Signed)); +} + +void CeilDivUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned)); +} + +void CeilDivSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed)); +} + +void FloorDivSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + return setResultRange( + getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed)); +} + +void RemSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferRemS, argRanges, CmpMode::Signed)); +} + +void RemUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned)); +} + +void MaxSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferMaxS, argRanges, CmpMode::Signed)); +} + +void MaxUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned)); +} + +void MinSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferMinS, argRanges, CmpMode::Signed)); +} + +void MinUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned)); +} + +void ShlOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both)); +} + +void ShrSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferShrS, argRanges, CmpMode::Signed)); +} + +void ShrUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned)); +} + +void AndOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned)); +} + +void OrOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferOr, argRanges, CmpMode::Unsigned)); +} + +void XOrOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferXor, argRanges, CmpMode::Unsigned)); +} + +//===----------------------------------------------------------------------===// +// Casts +//===----------------------------------------------------------------------===// + +static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range, + unsigned srcWidth, unsigned destWidth, + bool isSigned) { + if (srcWidth < destWidth) + return isSigned ? extSIRange(range, destWidth) + : extUIRange(range, destWidth); + if (srcWidth > destWidth) + return truncRange(range, destWidth); + return range; +} + +// When casting to `index`, we will take the union of the possible fixed-width +// casts. +static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range, + Type sourceType, Type destType, + bool isSigned) { + unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType); + unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); + if (sourceType.isIndex()) + return makeLikeDest(range, srcWidth, destWidth, isSigned); + // We are casting to indexs, so use the union of the 32-bit and 64-bit casts + ConstantIntRanges storageRange = + makeLikeDest(range, srcWidth, destWidth, isSigned); + ConstantIntRanges minWidthRange = + makeLikeDest(range, srcWidth, indexMinWidth, isSigned); + ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth); + ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt); + return ret; +} + +void CastSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + Type sourceType = getOperand().getType(); + Type destType = getResult().getType(); + setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType, + /*isSigned=*/true)); +} + +void CastUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + Type sourceType = getOperand().getType(); + Type destType = getResult().getType(); + setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType, + /*isSigned=*/false)); +} + +//===----------------------------------------------------------------------===// +// CmpOp +//===----------------------------------------------------------------------===// + +void CmpOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + index::IndexCmpPredicate indexPred = getPred(); + intrange::CmpPredicate pred = static_cast(indexPred); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + APInt min = APInt::getZero(1); + APInt max = APInt::getAllOnesValue(1); + + Optional truthValue64 = intrange::evaluatePred(pred, lhs, rhs); + + ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth), + rhsTrunc = truncRange(rhs, indexMinWidth); + Optional truthValue32 = + intrange::evaluatePred(pred, lhsTrunc, rhsTrunc); + + if (truthValue64 == truthValue32) { + if (truthValue64.has_value() && *truthValue64) + min = max; + else if (truthValue64.has_value() && !(*truthValue64)) + max = min; + } + setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); +} + +//===----------------------------------------------------------------------===// +// SizeOf, which is bounded between the two supported bitwidth (32 and 64). +//===----------------------------------------------------------------------===// + +void SizeOfOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + unsigned storageWidth = + ConstantIntRanges::getStorageBitwidth(getResult().getType()); + APInt min(/*numBits=*/storageWidth, indexMinWidth); + APInt max(/*numBits=*/storageWidth, indexMaxWidth); + setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); +} diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt index a7cdbb5..38ad0e4 100644 --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -51,3 +51,5 @@ add_mlir_interface_library(SideEffectInterfaces) add_mlir_interface_library(TilingInterface) add_mlir_interface_library(VectorInterfaces) add_mlir_interface_library(ViewLikeInterface) + +add_subdirectory(Utils) diff --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt new file mode 100644 index 0000000..ece6c8e --- /dev/null +++ b/mlir/lib/Interfaces/Utils/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_library(MLIRInferIntRangeCommon + InferIntRangeCommon.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces/Utils + + DEPENDS + MLIRInferIntRangeInterfaceIncGen + + LINK_LIBS PUBLIC + MLIRInferIntRangeInterface + MLIRIR +) diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp new file mode 100644 index 0000000..c81f004 --- /dev/null +++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp @@ -0,0 +1,663 @@ +//===- InferIntRangeCommon.cpp - Inference for common ops ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains implementations of range inference for operations that are +// common to both the `arith` and `index` dialects to facilitate reuse. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" + +#include "mlir/Interfaces/InferIntRangeInterface.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" + +#include "llvm/Support/Debug.h" + +#include +#include + +using namespace mlir; + +#define DEBUG_TYPE "int-range-analysis" + +//===----------------------------------------------------------------------===// +// General utilities +//===----------------------------------------------------------------------===// + +/// Function that evaluates the result of doing something on arithmetic +/// constants and returns std::nullopt on overflow. +using ConstArithFn = + function_ref(const APInt &, const APInt &)>; + +/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, +/// If either computation overflows, make the result unbounded. +static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, + const APInt &minRight, + const APInt &maxLeft, + const APInt &maxRight, bool isSigned) { + std::optional maybeMin = op(minLeft, minRight); + std::optional maybeMax = op(maxLeft, maxRight); + if (maybeMin && maybeMax) + return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned); + return ConstantIntRanges::maxRange(minLeft.getBitWidth()); +} + +/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, +/// ignoring unbounded values. Returns the maximal range if `op` overflows. +static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef lhs, + ArrayRef rhs, bool isSigned) { + unsigned width = lhs[0].getBitWidth(); + APInt min = + isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width); + APInt max = + isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width); + for (const APInt &left : lhs) { + for (const APInt &right : rhs) { + std::optional maybeThisResult = op(left, right); + if (!maybeThisResult) + return ConstantIntRanges::maxRange(width); + APInt result = std::move(*maybeThisResult); + min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min; + max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max; + } + } + return ConstantIntRanges::range(min, max, isSigned); +} + +//===----------------------------------------------------------------------===// +// Ext, trunc, index op handling +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferIndexOp(InferRangeFn inferFn, + ArrayRef argRanges, + intrange::CmpMode mode) { + ConstantIntRanges sixtyFour = inferFn(argRanges); + SmallVector truncated; + llvm::transform(argRanges, std::back_inserter(truncated), + [](const ConstantIntRanges &range) { + return truncRange(range, /*destWidth=*/indexMinWidth); + }); + ConstantIntRanges thirtyTwo = inferFn(truncated); + ConstantIntRanges thirtyTwoAsSixtyFour = + extRange(thirtyTwo, /*destWidth=*/indexMaxWidth); + ConstantIntRanges sixtyFourAsThirtyTwo = + truncRange(sixtyFour, /*destWidth=*/indexMinWidth); + + LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour + << " 32-bit = " << thirtyTwo << "\n"); + bool truncEqual = false; + switch (mode) { + case intrange::CmpMode::Both: + truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo); + break; + case intrange::CmpMode::Signed: + truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() && + thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax()); + break; + case intrange::CmpMode::Unsigned: + truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() && + thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax()); + break; + } + if (truncEqual) + // Returing the 64-bit result preserves more information. + return sixtyFour; + ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour); + return merged; +} + +ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range, + unsigned int destWidth) { + APInt umin = range.umin().zext(destWidth); + APInt umax = range.umax().zext(destWidth); + APInt smin = range.smin().sext(destWidth); + APInt smax = range.smax().sext(destWidth); + return {umin, umax, smin, smax}; +} + +ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range, + unsigned destWidth) { + APInt umin = range.umin().zext(destWidth); + APInt umax = range.umax().zext(destWidth); + return ConstantIntRanges::fromUnsigned(umin, umax); +} + +ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range, + unsigned destWidth) { + APInt smin = range.smin().sext(destWidth); + APInt smax = range.smax().sext(destWidth); + return ConstantIntRanges::fromSigned(smin, smax); +} + +ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range, + unsigned int destWidth) { + // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb], + // the range of the resulting value is not contiguous ind includes 0. + // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2], + // but you can't truncate [255, 257] similarly. + bool hasUnsignedRollover = + range.umin().lshr(destWidth) != range.umax().lshr(destWidth); + APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth) + : range.umin().trunc(destWidth); + APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth) + : range.umax().trunc(destWidth); + + // Signed post-truncation rollover will not occur when either: + // - The high parts of the min and max, plus the sign bit, are the same + // - The high halves + sign bit of the min and max are either all 1s or all 0s + // and you won't create a [positive, negative] range by truncating. + // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8 + // but not [255, 257]_i16 to a range of i8s. You can also truncate + // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16. + // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e) + // will truncate to 0x7e, which is greater than 0 + APInt sminHighPart = range.smin().ashr(destWidth - 1); + APInt smaxHighPart = range.smax().ashr(destWidth - 1); + bool hasSignedOverflow = + (sminHighPart != smaxHighPart) && + !(sminHighPart.isAllOnes() && + (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && + !(sminHighPart.isZero() && smaxHighPart.isZero()); + APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth) + : range.smin().trunc(destWidth); + APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth) + : range.smax().trunc(destWidth); + return {umin, umax, smin, smax}; +} + +//===----------------------------------------------------------------------===// +// Addition +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferAdd(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + ConstArithFn uadd = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.uadd_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + ConstArithFn sadd = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.sadd_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + + ConstantIntRanges urange = computeBoundsBy( + uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); + ConstantIntRanges srange = computeBoundsBy( + sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); + return urange.intersection(srange); +} + +//===----------------------------------------------------------------------===// +// Subtraction +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferSub(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + ConstArithFn usub = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.usub_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + ConstArithFn ssub = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.ssub_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + ConstantIntRanges urange = computeBoundsBy( + usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); + ConstantIntRanges srange = computeBoundsBy( + ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); + return urange.intersection(srange); +} + +//===----------------------------------------------------------------------===// +// Multiplication +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferMul(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + ConstArithFn umul = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.umul_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + ConstArithFn smul = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.smul_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + + ConstantIntRanges urange = + minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/false); + ConstantIntRanges srange = + minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, + /*isSigned=*/true); + return urange.intersection(srange); +} + +//===----------------------------------------------------------------------===// +// DivU, CeilDivU (Unsigned division) +//===----------------------------------------------------------------------===// + +/// Fix up division results (ex. for ceiling and floor), returning an APInt +/// if there has been no overflow +using DivisionFixupFn = function_ref( + const APInt &lhs, const APInt &rhs, const APInt &result)>; + +static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs, + DivisionFixupFn fixup) { + const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), + &rhsMax = rhs.umax(); + + if (!rhsMin.isZero()) { + auto udiv = [&fixup](const APInt &a, + const APInt &b) -> std::optional { + return fixup(a, b, a.udiv(b)); + }; + return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, + /*isSigned=*/false); + } + // Otherwise, it's possible we might divide by 0. + return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); +} + +ConstantIntRanges +mlir::intrange::inferDivU(ArrayRef argRanges) { + return inferDivURange(argRanges[0], argRanges[1], + [](const APInt &lhs, const APInt &rhs, + const APInt &result) { return result; }); +} + +ConstantIntRanges +mlir::intrange::inferCeilDivU(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + DivisionFixupFn ceilDivUIFix = + [](const APInt &lhs, const APInt &rhs, + const APInt &result) -> std::optional { + if (!lhs.urem(rhs).isZero()) { + bool overflowed = false; + APInt corrected = + result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); + return overflowed ? std::optional() : corrected; + } + return result; + }; + return inferDivURange(lhs, rhs, ceilDivUIFix); +} + +//===----------------------------------------------------------------------===// +// DivS, CeilDivS, FloorDivS (Signed division) +//===----------------------------------------------------------------------===// + +static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs, + DivisionFixupFn fixup) { + const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), + &rhsMax = rhs.smax(); + bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); + + if (canDivide) { + auto sdiv = [&fixup](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.sdiv_ov(b, overflowed); + return overflowed ? std::optional() : fixup(a, b, result); + }; + return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, + /*isSigned=*/true); + } + return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); +} + +ConstantIntRanges +mlir::intrange::inferDivS(ArrayRef argRanges) { + return inferDivSRange(argRanges[0], argRanges[1], + [](const APInt &lhs, const APInt &rhs, + const APInt &result) { return result; }); +} + +ConstantIntRanges +mlir::intrange::inferCeilDivS(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + DivisionFixupFn ceilDivSIFix = + [](const APInt &lhs, const APInt &rhs, + const APInt &result) -> std::optional { + if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { + bool overflowed = false; + APInt corrected = + result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed); + return overflowed ? std::optional() : corrected; + } + return result; + }; + return inferDivSRange(lhs, rhs, ceilDivSIFix); +} + +ConstantIntRanges +mlir::intrange::inferFloorDivS(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + DivisionFixupFn floorDivSIFix = + [](const APInt &lhs, const APInt &rhs, + const APInt &result) -> std::optional { + if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { + bool overflowed = false; + APInt corrected = + result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed); + return overflowed ? std::optional() : corrected; + } + return result; + }; + return inferDivSRange(lhs, rhs, floorDivSIFix); +} + +//===----------------------------------------------------------------------===// +// Signed remainder (RemS) +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferRemS(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), + &rhsMax = rhs.smax(); + + unsigned width = rhsMax.getBitWidth(); + APInt smin = APInt::getSignedMinValue(width); + APInt smax = APInt::getSignedMaxValue(width); + // No bounds if zero could be a divisor. + bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative()); + if (canBound) { + APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs(); + bool canNegativeDividend = lhsMin.isNegative(); + bool canPositiveDividend = lhsMax.isStrictlyPositive(); + APInt zero = APInt::getZero(maxDivisor.getBitWidth()); + APInt maxPositiveResult = maxDivisor - 1; + APInt minNegativeResult = -maxPositiveResult; + smin = canNegativeDividend ? minNegativeResult : zero; + smax = canPositiveDividend ? maxPositiveResult : zero; + // Special case: sweeping out a contiguous range in N/[modulus]. + if (rhsMin == rhsMax) { + if ((lhsMax - lhsMin).ult(maxDivisor)) { + APInt minRem = lhsMin.srem(maxDivisor); + APInt maxRem = lhsMax.srem(maxDivisor); + if (minRem.sle(maxRem)) { + smin = minRem; + smax = maxRem; + } + } + } + } + return ConstantIntRanges::fromSigned(smin, smax); +} + +//===----------------------------------------------------------------------===// +// Unsigned remainder (RemU) +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferRemU(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); + + unsigned width = rhsMin.getBitWidth(); + APInt umin = APInt::getZero(width); + APInt umax = APInt::getMaxValue(width); + + if (!rhsMin.isZero()) { + umax = rhsMax - 1; + // Special case: sweeping out a contiguous range in N/[modulus] + if (rhsMin == rhsMax) { + const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); + if ((lhsMax - lhsMin).ult(rhsMax)) { + APInt minRem = lhsMin.urem(rhsMax); + APInt maxRem = lhsMax.urem(rhsMax); + if (minRem.ule(maxRem)) { + umin = minRem; + umax = maxRem; + } + } + } + } + return ConstantIntRanges::fromUnsigned(umin, umax); +} + +//===----------------------------------------------------------------------===// +// Max and min (MaxS, MaxU, MinS, MinU) +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferMaxS(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); + const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax(); + return ConstantIntRanges::fromSigned(smin, smax); +} + +ConstantIntRanges +mlir::intrange::inferMaxU(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin(); + const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); + return ConstantIntRanges::fromUnsigned(umin, umax); +} + +ConstantIntRanges +mlir::intrange::inferMinS(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); + const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax(); + return ConstantIntRanges::fromSigned(smin, smax); +} + +ConstantIntRanges +mlir::intrange::inferMinU(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin(); + const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); + return ConstantIntRanges::fromUnsigned(umin, umax); +} + +//===----------------------------------------------------------------------===// +// Bitwise operators (And, Or, Xor) +//===----------------------------------------------------------------------===// + +/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, +/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits +/// that both bonuds have in common. This gives us a consertive approximation +/// for what values can be passed to bitwise operations. +static std::tuple +widenBitwiseBounds(const ConstantIntRanges &bound) { + APInt leftVal = bound.umin(), rightVal = bound.umax(); + unsigned bitwidth = leftVal.getBitWidth(); + unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros(); + leftVal.clearLowBits(differingBits); + rightVal.setLowBits(differingBits); + return std::make_tuple(std::move(leftVal), std::move(rightVal)); +} + +ConstantIntRanges +mlir::intrange::inferAnd(ArrayRef argRanges) { + auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); + auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); + auto andi = [](const APInt &a, const APInt &b) -> std::optional { + return a & b; + }; + return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, + /*isSigned=*/false); +} + +ConstantIntRanges +mlir::intrange::inferOr(ArrayRef argRanges) { + auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); + auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); + auto ori = [](const APInt &a, const APInt &b) -> std::optional { + return a | b; + }; + return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, + /*isSigned=*/false); +} + +ConstantIntRanges +mlir::intrange::inferXor(ArrayRef argRanges) { + auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); + auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); + auto xori = [](const APInt &a, const APInt &b) -> std::optional { + return a ^ b; + }; + return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, + /*isSigned=*/false); +} + +//===----------------------------------------------------------------------===// +// Shifts (Shl, ShrS, ShrU) +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferShl(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + ConstArithFn shl = [](const APInt &l, + const APInt &r) -> std::optional { + return r.uge(r.getBitWidth()) ? std::optional() : l.shl(r); + }; + ConstantIntRanges urange = + minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/false); + ConstantIntRanges srange = + minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/true); + return urange.intersection(srange); +} + +ConstantIntRanges +mlir::intrange::inferShrS(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + ConstArithFn ashr = [](const APInt &l, + const APInt &r) -> std::optional { + return r.uge(r.getBitWidth()) ? std::optional() : l.ashr(r); + }; + + return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/true); +} + +ConstantIntRanges +mlir::intrange::inferShrU(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + ConstArithFn lshr = [](const APInt &l, + const APInt &r) -> std::optional { + return r.uge(r.getBitWidth()) ? std::optional() : l.lshr(r); + }; + return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/false); +} + +//===----------------------------------------------------------------------===// +// Comparisons (Cmp) +//===----------------------------------------------------------------------===// + +static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) { + switch (pred) { + case intrange::CmpPredicate::eq: + return intrange::CmpPredicate::ne; + case intrange::CmpPredicate::ne: + return intrange::CmpPredicate::eq; + case intrange::CmpPredicate::slt: + return intrange::CmpPredicate::sge; + case intrange::CmpPredicate::sle: + return intrange::CmpPredicate::sgt; + case intrange::CmpPredicate::sgt: + return intrange::CmpPredicate::sle; + case intrange::CmpPredicate::sge: + return intrange::CmpPredicate::slt; + case intrange::CmpPredicate::ult: + return intrange::CmpPredicate::uge; + case intrange::CmpPredicate::ule: + return intrange::CmpPredicate::ugt; + case intrange::CmpPredicate::ugt: + return intrange::CmpPredicate::ule; + case intrange::CmpPredicate::uge: + return intrange::CmpPredicate::ult; + } + llvm_unreachable("unknown cmp predicate value"); +} + +static bool isStaticallyTrue(intrange::CmpPredicate pred, + const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs) { + switch (pred) { + case intrange::CmpPredicate::sle: + return lhs.smax().sle(rhs.smin()); + case intrange::CmpPredicate::slt: + return lhs.smax().slt(rhs.smin()); + case intrange::CmpPredicate::ule: + return lhs.umax().ule(rhs.umin()); + case intrange::CmpPredicate::ult: + return lhs.umax().ult(rhs.umin()); + case intrange::CmpPredicate::sge: + return lhs.smin().sge(rhs.smax()); + case intrange::CmpPredicate::sgt: + return lhs.smin().sgt(rhs.smax()); + case intrange::CmpPredicate::uge: + return lhs.umin().uge(rhs.umax()); + case intrange::CmpPredicate::ugt: + return lhs.umin().ugt(rhs.umax()); + case intrange::CmpPredicate::eq: { + std::optional lhsConst = lhs.getConstantValue(); + std::optional rhsConst = rhs.getConstantValue(); + return lhsConst && rhsConst && lhsConst == rhsConst; + } + case intrange::CmpPredicate::ne: { + // While equality requires that there is an interpration of the preceeding + // computations that produces equal constants, whether that be signed or + // unsigned, statically determining inequality requires that neither + // interpretation produce potentially overlapping ranges. + bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) || + isStaticallyTrue(intrange::CmpPredicate::sgt, lhs, rhs); + bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) || + isStaticallyTrue(intrange::CmpPredicate::ugt, lhs, rhs); + return sne && une; + } + } + return false; +} + +std::optional mlir::intrange::evaluatePred(CmpPredicate pred, + const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs) { + if (isStaticallyTrue(pred, lhs, rhs)) + return true; + if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) + return false; + return std::nullopt; +} diff --git a/mlir/test/Dialect/Index/int-range-inference.mlir b/mlir/test/Dialect/Index/int-range-inference.mlir new file mode 100644 index 0000000..2784d5f --- /dev/null +++ b/mlir/test/Dialect/Index/int-range-inference.mlir @@ -0,0 +1,66 @@ +// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s + +// Most operations are covered by the `arith` tests, which use the same code +// Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling +// code is operating as expected. + +// CHECK-LABEL: func @add_same_for_both +// CHECK: %[[true:.*]] = index.bool.constant true +// CHECK: return %[[true]] +func.func @add_same_for_both(%arg0 : index) -> i1 { + %c1 = index.constant 1 + %calmostBig = index.constant 0xfffffffe + %0 = index.minu %arg0, %calmostBig + %1 = index.add %0, %c1 + %2 = index.cmp uge(%1, %c1) + func.return %2 : i1 +} + +// CHECK-LABEL: func @add_unsigned_ov +// CHECK: %[[uge:.*]] = index.cmp uge +// CHECK: return %[[uge]] +func.func @add_unsigned_ov(%arg0 : index) -> i1 { + %c1 = index.constant 1 + %cu32_max = index.constant 0xffffffff + %0 = index.minu %arg0, %cu32_max + %1 = index.add %0, %c1 + // On 32-bit, the add could wrap, so the result doesn't have to be >= 1 + %2 = index.cmp uge(%1, %c1) + func.return %2 : i1 +} + +// CHECK-LABEL: func @add_signed_ov +// CHECK: %[[sge:.*]] = index.cmp sge +// CHECK: return %[[sge]] +func.func @add_signed_ov(%arg0 : index) -> i1 { + %c0 = index.constant 0 + %c1 = index.constant 1 + %ci32_max = index.constant 0x7fffffff + %0 = index.minu %arg0, %ci32_max + %1 = index.add %0, %c1 + // On 32-bit, the add could wrap, so the result doesn't have to be positive + %2 = index.cmp sge(%1, %c0) + func.return %2 : i1 +} + +// CHECK-LABEL: func @add_big +// CHECK: %[[true:.*]] = index.bool.constant true +// CHECK: return %[[true]] +func.func @add_big(%arg0 : index) -> i1 { + %c1 = index.constant 1 + %cmin = index.constant 0x300000000 + %cmax = index.constant 0x30000ffff + // Note: the order of the clamps matters. + // If you go max, then min, you infer the ranges [0x300...0, 0xff..ff] + // and then [0x30...0000, 0x30...ffff] + // If you switch the order of the below operations, you instead first infer + // the range [0,0x3...ffff]. Then, the min inference can't constraint + // this intermediate, since in the 32-bit case we could have, for example + // trunc(%arg0 = 0x2ffffffff) = 0xffffffff > trunc(0x30000ffff) = 0x0000ffff + // which means we can't do any inference. + %0 = index.maxu %arg0, %cmin + %1 = index.minu %0, %cmax + %2 = index.add %1, %c1 + %3 = index.cmp uge(%1, %cmin) + func.return %3 : i1 +}