[mlir][arith] Optimize arith.cmpi based on integer range analysis.
authorIvan Butygin <ivan.butygin@gmail.com>
Fri, 23 Dec 2022 15:20:20 +0000 (16:20 +0100)
committerIvan Butygin <ivan.butygin@gmail.com>
Wed, 11 Jan 2023 11:15:58 +0000 (12:15 +0100)
Add a pass which do arith dialect ops optimization based on integer range analysis (only cmpi for now).

Differential Revision: https://reviews.llvm.org/D140629

mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp [new file with mode: 0644]
mlir/test/Dialect/Arith/int-range-opts.mlir [new file with mode: 0644]

index d087ac6..257a62a 100644 (file)
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+class DataFlowSolver;
+
 namespace arith {
 
 #define GEN_PASS_DECL
 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+#define GEN_PASS_DECL_ARITHINTRANGEOPTS
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
 
 class WideIntEmulationConverter;
 
@@ -44,6 +48,13 @@ std::unique_ptr<Pass> createArithExpandOpsPass();
 /// equivalent.
 std::unique_ptr<Pass> createArithUnsignedWhenEquivalentPass();
 
+/// Add patterns for int range based optimizations.
+void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
+                                           DataFlowSolver &solver);
+
+/// Create a pass which do optimizations based on integer range analysis.
+std::unique_ptr<Pass> createIntRangeOptimizationsPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
index 16ef294..ee561e6 100644 (file)
@@ -49,6 +49,15 @@ def ArithUnsignedWhenEquivalent : Pass<"arith-unsigned-when-equivalent"> {
   let constructor = "mlir::arith::createArithUnsignedWhenEquivalentPass()";
 }
 
+def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
+  let summary = "Do optimizations based on integer range analysis";
+  let description = [{
+    This pass runs integer range analysis and apllies optimizations based on its
+    results. e.g. replace arith.cmpi with const if it can be inferred from
+    args ranges.
+  }];
+}
+
 def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
   let summary = "Emulate 2*N-bit integer operations using N-bit operations";
   let description = [{
index b45ae48..9f098f0 100644 (file)
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRArithTransforms
   Bufferize.cpp
   EmulateWideInt.cpp
   ExpandOps.cpp
+  IntRangeOptimizations.cpp
   UnsignedWhenEquivalent.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
new file mode 100644 (file)
index 0000000..7f34c0a
--- /dev/null
@@ -0,0 +1,183 @@
+//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::arith {
+#define GEN_PASS_DEF_ARITHINTRANGEOPTS
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+} // namespace mlir::arith
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace mlir::dataflow;
+
+/// Returns true if 2 integer ranges have intersection.
+static bool intersects(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) &&
+           (lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax())));
+}
+
+static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  if (!intersects(lhs, rhs))
+    return false;
+
+  return failure();
+}
+
+static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  if (!intersects(lhs, rhs))
+    return true;
+
+  return failure();
+}
+
+static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  if (lhs.smax().slt(rhs.smin()))
+    return true;
+
+  if (lhs.smin().sge(rhs.smax()))
+    return false;
+
+  return failure();
+}
+
+static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  if (lhs.smax().sle(rhs.smin()))
+    return true;
+
+  if (lhs.smin().sgt(rhs.smax()))
+    return false;
+
+  return failure();
+}
+
+static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  return handleSlt(rhs, lhs);
+}
+
+static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  return handleSle(rhs, lhs);
+}
+
+static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  if (lhs.umax().ult(rhs.umin()))
+    return true;
+
+  if (lhs.umin().uge(rhs.umax()))
+    return false;
+
+  return failure();
+}
+
+static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  if (lhs.umax().ule(rhs.umin()))
+    return true;
+
+  if (lhs.umin().ugt(rhs.umax()))
+    return false;
+
+  return failure();
+}
+
+static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  return handleUlt(rhs, lhs);
+}
+
+static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  return handleUle(rhs, lhs);
+}
+
+namespace {
+struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
+
+  ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
+      : OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
+
+  LogicalResult matchAndRewrite(arith::CmpIOp op,
+                                PatternRewriter &rewriter) const override {
+    auto *lhsResult =
+        solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
+    if (!lhsResult || lhsResult->getValue().isUninitialized())
+      return failure();
+
+    auto *rhsResult =
+        solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
+    if (!rhsResult || rhsResult->getValue().isUninitialized())
+      return failure();
+
+    using HandlerFunc =
+        FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
+    std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
+        handlers{};
+    using Pred = arith::CmpIPredicate;
+    handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
+    handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
+    handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
+    handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
+    handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
+    handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
+    handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
+    handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
+    handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
+    handlers[static_cast<size_t>(Pred::uge)] = &handleUge;
+
+    HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
+    if (!handler)
+      return failure();
+
+    ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
+    ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
+    FailureOr<bool> result = handler(lhsValue, rhsValue);
+
+    if (failed(result))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
+        op, static_cast<int64_t>(*result), /*width*/ 1);
+    return success();
+  }
+
+private:
+  DataFlowSolver &solver;
+};
+
+struct IntRangeOptimizationsPass
+    : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+    DataFlowSolver solver;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<IntegerRangeAnalysis>();
+    if (failed(solver.initializeAndRun(op)))
+      return signalPassFailure();
+
+    RewritePatternSet patterns(ctx);
+    populateIntRangeOptimizationsPatterns(patterns, solver);
+
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::arith::populateIntRangeOptimizationsPatterns(
+    RewritePatternSet &patterns, DataFlowSolver &solver) {
+  patterns.add<ConvertCmpOp>(patterns.getContext(), solver);
+}
+
+std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
+  return std::make_unique<IntRangeOptimizationsPass>();
+}
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
new file mode 100644 (file)
index 0000000..be0a7e8
--- /dev/null
@@ -0,0 +1,73 @@
+// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @test
+//       CHECK:   %[[C:.*]] = arith.constant false
+//       CHECK:   return %[[C]]
+func.func @test() -> i1 {
+  %cst1 = arith.constant -1 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi eq, %0, %cst1 : index
+  return %1: i1
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+//       CHECK:   %[[C:.*]] = arith.constant true
+//       CHECK:   return %[[C]]
+func.func @test() -> i1 {
+  %cst1 = arith.constant -1 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi ne, %0, %cst1 : index
+  return %1: i1
+}
+
+// -----
+
+
+// CHECK-LABEL: func @test
+//       CHECK:   %[[C:.*]] = arith.constant true
+//       CHECK:   return %[[C]]
+func.func @test() -> i1 {
+  %cst = arith.constant 0 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi sge, %0, %cst : index
+  return %1: i1
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+//       CHECK:   %[[C:.*]] = arith.constant false
+//       CHECK:   return %[[C]]
+func.func @test() -> i1 {
+  %cst = arith.constant 0 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi slt, %0, %cst : index
+  return %1: i1
+}
+
+// -----
+
+
+// CHECK-LABEL: func @test
+//       CHECK:   %[[C:.*]] = arith.constant true
+//       CHECK:   return %[[C]]
+func.func @test() -> i1 {
+  %cst1 = arith.constant -1 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi sgt, %0, %cst1 : index
+  return %1: i1
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+//       CHECK:   %[[C:.*]] = arith.constant false
+//       CHECK:   return %[[C]]
+func.func @test() -> i1 {
+  %cst1 = arith.constant -1 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi sle, %0, %cst1 : index
+  return %1: i1
+}