From b0b00432093be9680ed833af642bcafc3ca11586 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Thu, 14 Apr 2022 22:51:23 +0000 Subject: [PATCH] [mlir][Arith] Pass to switch signed ops for equivalent unsigned ones If all the arguments to and results of an operation are known to be non-negative when interpreted as signed (which also implies that all computations producing those values did not experience signed overflow), we can replace that operation with an equivalent one that operates on unsigned values. Such a replacement, when it is possible, can provide useful hints to backends, such as by allowing LLVM to replace remainder with bitwise operations in more cases. Depends on D124022 Depends on D124023 Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D124024 --- .../mlir/Dialect/Arithmetic/Transforms/Passes.h | 4 + .../mlir/Dialect/Arithmetic/Transforms/Passes.td | 16 +++ .../Dialect/Arithmetic/Transforms/CMakeLists.txt | 3 + .../Transforms/UnsignedWhenEquivalent.cpp | 144 +++++++++++++++++++++ .../Arithmetic/unsigned-when-equivalent.mlir | 88 +++++++++++++ 5 files changed, 255 insertions(+) create mode 100644 mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp create mode 100644 mlir/test/Dialect/Arithmetic/unsigned-when-equivalent.mlir diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h index 1acea57..9b9331f 100644 --- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h @@ -26,6 +26,10 @@ void populateArithmeticExpandOpsPatterns(RewritePatternSet &patterns); /// Create a pass to legalize Arithmetic ops for LLVM lowering. std::unique_ptr createArithmeticExpandOpsPass(); +/// Create a pass to replace signed ops with unsigned ones where they are proven +/// equivalent. 
+std::unique_ptr<Pass> createArithmeticUnsignedWhenEquivalentPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
index 1d84e27..752d715 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
@@ -33,4 +33,20 @@ def ArithmeticExpandOps : Pass<"arith-expand"> {
   let constructor = "mlir::arith::createArithmeticExpandOpsPass()";
 }
 
+def ArithmeticUnsignedWhenEquivalent : Pass<"arith-unsigned-when-equivalent"> {
+  let summary = "Replace signed ops with unsigned ones where they are proven equivalent";
+  let description = [{
+    Replace signed ops with their unsigned equivalents when integer range analysis
+    determines that their arguments and results are all guaranteed to be
+    non-negative when interpreted as signed integers. When this occurs,
+    we know that the semantics of the signed and unsigned operations are the same,
+    since they share the same behavior when their operands and results are in the
+    range [0, signed_max(type)].
+
+    The affected ops include division, remainder, sign-extension, min, max, and
+    integer comparisons.
+ }]; + let constructor = "mlir::arith::createArithmeticUnsignedWhenEquivalentPass()"; +} + #endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt index 7f5c9ca..f140715 100644 --- a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithmeticTransforms BufferizableOpInterfaceImpl.cpp Bufferize.cpp ExpandOps.cpp + UnsignedWhenEquivalent.cpp ADDITIONAL_HEADER_DIRS {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic/Transforms @@ -10,9 +11,11 @@ add_mlir_dialect_library(MLIRArithmeticTransforms MLIRArithmeticTransformsIncGen LINK_LIBS PUBLIC + MLIRAnalysis MLIRArithmeticDialect MLIRBufferizationDialect MLIRBufferizationTransforms + MLIRInferIntRangeInterface MLIRIR MLIRMemRefDialect MLIRPass diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp new file mode 100644 index 0000000..30fb517 --- /dev/null +++ b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp @@ -0,0 +1,144 @@ +//===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with +// unsigned +// ones when all their arguments and results are statically non-negative --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Analysis/IntRangeAnalysis.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::arith; + +using OpList = llvm::SmallVector; + +/// Returns true when a value is statically non-negative in that it has a lower +/// bound on its value (if it is treated as signed) and that bound is +/// non-negative. +static bool staticallyNonNegative(IntRangeAnalysis &analysis, Value v) { + Optional result = analysis.getResult(v); + if (!result.hasValue()) + return false; + const ConstantIntRanges &range = result.getValue(); + return (range.smin().isNonNegative()); +} + +/// Identify all operations in a block that have signed equivalents and have +/// operands and results that are statically non-negative. +template +static void getConvertableOps(Operation *root, OpList &toRewrite, + IntRangeAnalysis &analysis) { + auto nonNegativePred = [&analysis](Value v) -> bool { + return staticallyNonNegative(analysis, v); + }; + root->walk([&nonNegativePred, &toRewrite](Operation *orig) { + if (isa(orig) && + llvm::all_of(orig->getOperands(), nonNegativePred) && + llvm::all_of(orig->getResults(), nonNegativePred)) { + toRewrite.push_back(orig); + } + }); +} + +static CmpIPredicate toUnsignedPred(CmpIPredicate pred) { + switch (pred) { + case CmpIPredicate::sle: + return CmpIPredicate::ule; + case CmpIPredicate::slt: + return CmpIPredicate::ult; + case CmpIPredicate::sge: + return CmpIPredicate::uge; + case CmpIPredicate::sgt: + return CmpIPredicate::ugt; + default: + return pred; + } +} + +/// Find all cmpi ops that can be replaced by their unsigned equivalents. 
+static void getConvertableCmpi(Operation *root, OpList &toRewrite, + IntRangeAnalysis &analysis) { + auto nonNegativePred = [&analysis](Value v) -> bool { + return staticallyNonNegative(analysis, v); + }; + root->walk([&nonNegativePred, &toRewrite](arith::CmpIOp orig) { + CmpIPredicate pred = orig.getPredicate(); + if (toUnsignedPred(pred) != pred && + // i1 will spuriously and trivially show up as pontentially negative, + // so don't check the results + llvm::all_of(orig->getOperands(), nonNegativePred)) { + toRewrite.push_back(orig.getOperation()); + } + }); +} + +/// Return ops to be replaced in the order they should be rewritten. +static OpList getMatching(Operation *root, IntRangeAnalysis &analysis) { + OpList ret; + getConvertableOps(root, ret, analysis); + // Since these are in-place changes, they don't need to be topological order + // like the others. + getConvertableCmpi(root, ret, analysis); + return ret; +} + +template +static void rewriteOp(Operation *op, OpBuilder &b) { + if (isa(op)) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPoint(op); + Operation *newOp = b.create(op->getLoc(), op->getResultTypes(), + op->getOperands(), op->getAttrs()); + op->replaceAllUsesWith(newOp->getResults()); + op->erase(); + } +} + +static void rewriteCmpI(Operation *op, OpBuilder &b) { + if (auto cmpOp = dyn_cast(op)) { + cmpOp.setPredicateAttr(CmpIPredicateAttr::get( + b.getContext(), toUnsignedPred(cmpOp.getPredicate()))); + } +} + +static void rewrite(Operation *root, const OpList &toReplace) { + OpBuilder b(root->getContext()); + b.setInsertionPoint(root); + for (Operation *op : toReplace) { + rewriteOp(op, b); + rewriteOp(op, b); + rewriteOp(op, b); + rewriteOp(op, b); + rewriteOp(op, b); + rewriteOp(op, b); + rewriteOp(op, b); + rewriteCmpI(op, b); + } +} + +namespace { +struct ArithmeticUnsignedWhenEquivalentPass + : public ArithmeticUnsignedWhenEquivalentBase< + ArithmeticUnsignedWhenEquivalentPass> { + /// Implementation structure: first find all 
equivalent ops and collect them, + /// then perform all the rewrites in a second pass over the target op. This + /// ensures that analysis results are not invalidated during rewriting. + void runOnOperation() override { + Operation *op = getOperation(); + IntRangeAnalysis analysis(op); + rewrite(op, getMatching(op, analysis)); + } +}; +} // end anonymous namespace + +std::unique_ptr +mlir::arith::createArithmeticUnsignedWhenEquivalentPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Arithmetic/unsigned-when-equivalent.mlir b/mlir/test/Dialect/Arithmetic/unsigned-when-equivalent.mlir new file mode 100644 index 0000000..558c9f4 --- /dev/null +++ b/mlir/test/Dialect/Arithmetic/unsigned-when-equivalent.mlir @@ -0,0 +1,88 @@ +// RUN: mlir-opt -arith-unsigned-when-equivalent %s | FileCheck %s + +// CHECK-LABEL func @not_with_maybe_overflow +// CHECK: arith.divsi +// CHECK: arith.ceildivsi +// CHECK: arith.floordivsi +// CHECK: arith.remsi +// CHECK: arith.minsi +// CHECK: arith.maxsi +// CHECK: arith.extsi +// CHECK: arith.cmpi sle +// CHECK: arith.cmpi slt +// CHECK: arith.cmpi sge +// CHECK: arith.cmpi sgt +func.func @not_with_maybe_overflow(%arg0 : i32) { + %ci32_smax = arith.constant 0x7fffffff : i32 + %c1 = arith.constant 1 : i32 + %c4 = arith.constant 4 : i32 + %0 = arith.minui %arg0, %ci32_smax : i32 + %1 = arith.addi %0, %c1 : i32 + %2 = arith.divsi %1, %c4 : i32 + %3 = arith.ceildivsi %1, %c4 : i32 + %4 = arith.floordivsi %1, %c4 : i32 + %5 = arith.remsi %1, %c4 : i32 + %6 = arith.minsi %1, %c4 : i32 + %7 = arith.maxsi %1, %c4 : i32 + %8 = arith.extsi %1 : i32 to i64 + %9 = arith.cmpi sle, %1, %c4 : i32 + %10 = arith.cmpi slt, %1, %c4 : i32 + %11 = arith.cmpi sge, %1, %c4 : i32 + %12 = arith.cmpi sgt, %1, %c4 : i32 + func.return +} + +// CHECK-LABEL func @yes_with_no_overflow +// CHECK: arith.divui +// CHECK: arith.ceildivui +// CHECK: arith.divui +// CHECK: arith.remui +// CHECK: arith.minui +// CHECK: arith.maxui +// CHECK: arith.extui 
+// CHECK: arith.cmpi ule
+// CHECK: arith.cmpi ult
+// CHECK: arith.cmpi uge
+// CHECK: arith.cmpi ugt
+func.func @yes_with_no_overflow(%arg0 : i32) {
+  %ci32_almost_smax = arith.constant 0x7ffffffe : i32
+  %c1 = arith.constant 1 : i32
+  %c4 = arith.constant 4 : i32
+  %0 = arith.minui %arg0, %ci32_almost_smax : i32
+  %1 = arith.addi %0, %c1 : i32
+  %2 = arith.divsi %1, %c4 : i32
+  %3 = arith.ceildivsi %1, %c4 : i32
+  %4 = arith.floordivsi %1, %c4 : i32
+  %5 = arith.remsi %1, %c4 : i32
+  %6 = arith.minsi %1, %c4 : i32
+  %7 = arith.maxsi %1, %c4 : i32
+  %8 = arith.extsi %1 : i32 to i64
+  %9 = arith.cmpi sle, %1, %c4 : i32
+  %10 = arith.cmpi slt, %1, %c4 : i32
+  %11 = arith.cmpi sge, %1, %c4 : i32
+  %12 = arith.cmpi sgt, %1, %c4 : i32
+  func.return
+}
+
+// CHECK-LABEL: func @preserves_structure
+// CHECK: scf.for %[[arg1:.*]] =
+// CHECK: %[[v:.*]] = arith.remui %[[arg1]]
+// CHECK: %[[w:.*]] = arith.addi %[[v]], %[[v]]
+// CHECK: %[[test:.*]] = arith.cmpi ule, %[[w]]
+// CHECK: scf.if %[[test]]
+// CHECK: memref.store %[[w]]
+func.func @preserves_structure(%arg0 : memref<8xindex>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c8 = arith.constant 8 : index
+  scf.for %arg1 = %c0 to %c8 step %c1 {
+    %v = arith.remsi %arg1, %c4 : index
+    %w = arith.addi %v, %v : index
+    %test = arith.cmpi sle, %w, %c4 : index
+    scf.if %test {
+      memref.store %w, %arg0[%arg1] : memref<8xindex>
+    }
+  }
+  func.return
+}
-- 
2.7.4