[ConstraintSystem] Add helpers to deal with linear constraints.
authorFlorian Hahn <flo@fhahn.com>
Fri, 11 Sep 2020 13:33:06 +0000 (14:33 +0100)
committerFlorian Hahn <flo@fhahn.com>
Fri, 11 Sep 2020 13:43:22 +0000 (14:43 +0100)
This patch introduces a new ConstraintSystem class, that maintains a set
of linear constraints and uses Fourier–Motzkin elimination to eliminate
constraints to check if there are solutions for the system.

It also adds a convert-constraint-log-to-z3.py script, which can parse
the debug output of the constraint system and convert it to a python
script that feeds the constraints into Z3 and checks if it produces the
same result as the LLVM implementation. This is for verification
purposes.

Reviewed By: spatel

Differential Revision: https://reviews.llvm.org/D84544

llvm/include/llvm/Analysis/ConstraintSystem.h [new file with mode: 0644]
llvm/lib/Analysis/CMakeLists.txt
llvm/lib/Analysis/ConstraintSystem.cpp [new file with mode: 0644]
llvm/unittests/Analysis/CMakeLists.txt
llvm/unittests/Analysis/ConstraintSystemTest.cpp [new file with mode: 0644]
llvm/utils/convert-constraint-log-to-z3.py [new file with mode: 0755]

diff --git a/llvm/include/llvm/Analysis/ConstraintSystem.h b/llvm/include/llvm/Analysis/ConstraintSystem.h
new file mode 100644 (file)
index 0000000..7de787c
--- /dev/null
@@ -0,0 +1,57 @@
+//===- ConstraintSystem.h -  A system of linear constraints. --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_CONSTRAINTSYSTEM_H
+#define LLVM_ANALYSIS_CONSTRAINTSYSTEM_H
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include <string>
+
+namespace llvm {
+
+class ConstraintSystem {
+  /// Current linear constraints in the system.
+  /// An entry of the form c0, c1, ... cn represents the following constraint:
+  ///   c0 >= v0 * c1 + .... + v{n-1} * cn
+  SmallVector<SmallVector<int64_t, 8>, 4> Constraints;
+
+  /// Current greatest common divisor for all coefficients in the system.
+  uint32_t GCD = 1;
+
+  // Eliminate constraints from the system using Fourier–Motzkin elimination.
+  bool eliminateUsingFM();
+
+  /// Print the constraints in the system, using \p Names as variable names.
+  void dump(ArrayRef<std::string> Names) const;
+
+  /// Print the constraints in the system, using x0...xn as variable names.
+  void dump() const;
+
+  /// Returns true if there may be a solution for the constraints in the system.
+  bool mayHaveSolutionImpl();
+
+public:
+  void addVariableRow(const SmallVector<int64_t, 8> &R) {
+    assert(Constraints.empty() || R.size() == Constraints.back().size());
+    for (const auto &C : R) {
+      auto A = std::abs(C);
+      GCD = APIntOps::GreatestCommonDivisor({32, (uint32_t)A}, {32, GCD})
+                .getZExtValue();
+    }
+    Constraints.push_back(R);
+  }
+
+  /// Returns true if there may be a solution for the constraints in the system.
+  bool mayHaveSolution();
+};
+} // namespace llvm
+
+#endif // LLVM_ANALYSIS_CONSTRAINTSYSTEM_H
index f50439b..78cc764 100644 (file)
@@ -39,6 +39,7 @@ add_llvm_component_library(LLVMAnalysis
   CodeMetrics.cpp
   ConstantFolding.cpp
   DDG.cpp
+  ConstraintSystem.cpp
   Delinearization.cpp
   DemandedBits.cpp
   DependenceAnalysis.cpp
diff --git a/llvm/lib/Analysis/ConstraintSystem.cpp b/llvm/lib/Analysis/ConstraintSystem.cpp
new file mode 100644 (file)
index 0000000..95fe6c9
--- /dev/null
@@ -0,0 +1,141 @@
+//===- ConstraintSytem.cpp - A system of linear constraints. ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ConstraintSystem.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Debug.h"
+
+#include <algorithm>
+#include <string>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "constraint-system"
+
+bool ConstraintSystem::eliminateUsingFM() {
+  // Implementation of Fourier–Motzkin elimination, with some tricks from the
+  // paper Pugh, William. "The Omega test: a fast and practical integer
+  // programming algorithm for dependence
+  //  analysis."
+  // Supercomputing'91: Proceedings of the 1991 ACM/
+  // IEEE conference on Supercomputing. IEEE, 1991.
+  assert(!Constraints.empty() &&
+         "should only be called for non-empty constraint systems");
+  unsigned NumVariables = Constraints[0].size();
+  SmallVector<SmallVector<int64_t, 8>, 4> NewSystem;
+
+  unsigned NumConstraints = Constraints.size();
+  uint32_t NewGCD = 1;
+  // FIXME do not use copy
+  for (unsigned R1 = 0; R1 < NumConstraints; R1++) {
+    if (Constraints[R1][1] == 0) {
+      SmallVector<int64_t, 8> NR;
+      NR.push_back(Constraints[R1][0]);
+      for (unsigned i = 2; i < NumVariables; i++) {
+        NR.push_back(Constraints[R1][i]);
+      }
+      NewSystem.push_back(std::move(NR));
+      continue;
+    }
+
+    // FIXME do not use copy
+    bool EliminatedInRow = false;
+    for (unsigned R2 = R1 + 1; R2 < NumConstraints; R2++) {
+      if (R1 == R2)
+        continue;
+
+      // FIXME: can we do better than just dropping things here?
+      if (Constraints[R2][1] == 0)
+        continue;
+
+      if ((Constraints[R1][1] < 0 && Constraints[R2][1] < 0) ||
+          (Constraints[R1][1] > 0 && Constraints[R2][1] > 0))
+        continue;
+
+      unsigned LowerR = R1;
+      unsigned UpperR = R2;
+      if (Constraints[UpperR][1] < 0)
+        std::swap(LowerR, UpperR);
+
+      SmallVector<int64_t, 8> NR;
+      for (unsigned I = 0; I < NumVariables; I++) {
+        if (I == 1)
+          continue;
+
+        int64_t M1, M2, N;
+        if (__builtin_mul_overflow(Constraints[UpperR][I],
+                                   ((-1) * Constraints[LowerR][1] / GCD), &M1))
+          return false;
+        if (__builtin_mul_overflow(Constraints[LowerR][I],
+                                   (Constraints[UpperR][1] / GCD), &M2))
+          return false;
+        if (__builtin_add_overflow(M1, M2, &N))
+          return false;
+        NR.push_back(N);
+
+        NewGCD = APIntOps::GreatestCommonDivisor({32, (uint32_t)NR.back()},
+                                                 {32, NewGCD})
+                     .getZExtValue();
+      }
+      NewSystem.push_back(std::move(NR));
+      EliminatedInRow = true;
+    }
+  }
+  Constraints = std::move(NewSystem);
+  GCD = NewGCD;
+
+  return true;
+}
+
+bool ConstraintSystem::mayHaveSolutionImpl() {
+  while (!Constraints.empty() && Constraints[0].size() > 1) {
+    if (!eliminateUsingFM())
+      return true;
+  }
+
+  if (Constraints.empty() || Constraints[0].size() > 1)
+    return true;
+
+  return all_of(Constraints, [](auto &R) { return R[0] >= 0; });
+}
+
+void ConstraintSystem::dump(ArrayRef<std::string> Names) const {
+  if (Constraints.empty())
+    return;
+
+  for (auto &Row : Constraints) {
+    SmallVector<std::string, 16> Parts;
+    for (unsigned I = 1, S = Row.size(); I < S; ++I) {
+      if (Row[I] == 0)
+        continue;
+      std::string Coefficient = "";
+      if (Row[I] != 1)
+        Coefficient = std::to_string(Row[I]) + " * ";
+      Parts.push_back(Coefficient + Names[I - 1]);
+    }
+    assert(!Parts.empty() && "need to have at least some parts");
+    LLVM_DEBUG(dbgs() << join(Parts, std::string(" + "))
+                      << " <= " << std::to_string(Row[0]) << "\n");
+  }
+}
+
+void ConstraintSystem::dump() const {
+  SmallVector<std::string, 16> Names;
+  for (unsigned i = 1; i < Constraints.back().size(); ++i)
+    Names.push_back("x" + std::to_string(i));
+  LLVM_DEBUG(dbgs() << "---\n");
+  dump(Names);
+}
+
+bool ConstraintSystem::mayHaveSolution() {
+  dump();
+  bool HasSolution = mayHaveSolutionImpl();
+  LLVM_DEBUG(dbgs() << (HasSolution ? "sat" : "unsat") << "\n");
+  return HasSolution;
+}
index eb97f62..dfe570f 100644 (file)
@@ -23,6 +23,7 @@ add_llvm_unittest_with_input_files(AnalysisTests
   CaptureTrackingTest.cpp
   CFGTest.cpp
   CGSCCPassManagerTest.cpp
+  ConstraintSystemTest.cpp
   DDGTest.cpp
   DivergenceAnalysisTest.cpp
   DomTreeUpdaterTest.cpp
diff --git a/llvm/unittests/Analysis/ConstraintSystemTest.cpp b/llvm/unittests/Analysis/ConstraintSystemTest.cpp
new file mode 100644 (file)
index 0000000..2301da7
--- /dev/null
@@ -0,0 +1,82 @@
+//===--- ConstraintSystemTests.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ConstraintSystem.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+TEST(ConstraintSloverTest, TestSolutionChecks) {
+  {
+    ConstraintSystem CS;
+    // x + y <= 10, x >= 5, y >= 6, x <= 10, y <= 10
+    CS.addVariableRow({10, 1, 1});
+    CS.addVariableRow({-5, -1, 0});
+    CS.addVariableRow({-6, 0, -1});
+    CS.addVariableRow({10, 1, 0});
+    CS.addVariableRow({10, 0, 1});
+
+    EXPECT_FALSE(CS.mayHaveSolution());
+  }
+
+  {
+    ConstraintSystem CS;
+    // x + y <= 10, x >= 2, y >= 3, x <= 10, y <= 10
+    CS.addVariableRow({10, 1, 1});
+    CS.addVariableRow({-2, -1, 0});
+    CS.addVariableRow({-3, 0, -1});
+    CS.addVariableRow({10, 1, 0});
+    CS.addVariableRow({10, 0, 1});
+
+    EXPECT_TRUE(CS.mayHaveSolution());
+  }
+
+  {
+    ConstraintSystem CS;
+    // x + y <= 10, 10 >= x, 10 >= y; does not have a solution.
+    CS.addVariableRow({10, 1, 1});
+    CS.addVariableRow({-10, -1, 0});
+    CS.addVariableRow({-10, 0, -1});
+
+    EXPECT_FALSE(CS.mayHaveSolution());
+  }
+
+  {
+    ConstraintSystem CS;
+    // x + y >= 20, 10 >= x, 10 >= y; does HAVE a solution.
+    CS.addVariableRow({-20, -1, -1});
+    CS.addVariableRow({-10, -1, 0});
+    CS.addVariableRow({-10, 0, -1});
+
+    EXPECT_TRUE(CS.mayHaveSolution());
+  }
+
+  {
+    ConstraintSystem CS;
+
+    // 2x + y + 3z <= 10,  2x + y >= 10, y >= 1
+    CS.addVariableRow({10, 2, 1, 3});
+    CS.addVariableRow({-10, -2, -1, 0});
+    CS.addVariableRow({-1, 0, 0, -1});
+
+    EXPECT_FALSE(CS.mayHaveSolution());
+  }
+
+  {
+    ConstraintSystem CS;
+
+    // 2x + y + 3z <= 10,  2x + y >= 10
+    CS.addVariableRow({10, 2, 1, 3});
+    CS.addVariableRow({-10, -2, -1, 0});
+
+    EXPECT_TRUE(CS.mayHaveSolution());
+  }
+}
+} // namespace
diff --git a/llvm/utils/convert-constraint-log-to-z3.py b/llvm/utils/convert-constraint-log-to-z3.py
new file mode 100755 (executable)
index 0000000..77b0a3d
--- /dev/null
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+
+"""
+Helper script to convert the log generated by '-debug-only=constraint-system'
+to a Python script that uses Z3 to verify the decisions using Z3's Python API.
+
+Example usage:
+
+> cat path/to/file.log
+---
+x6 + -1 * x7 <= -1
+x6 + -1 * x7 <= -2
+sat
+
+> ./convert-constraint-log-to-z3.py path/to/file.log > check.py && python ./check.py
+
+> cat check.py
+    from z3 import *
+x3 = Int("x3")
+x1 = Int("x1")
+x2 = Int("x2")
+s = Solver()
+s.add(x1 + -1 * x2 <= 0)
+s.add(x2 + -1 * x3 <= 0)
+s.add(-1 * x1 + x3 <= -1)
+assert(s.check() == unsat)
+print('all checks passed')
+"""
+
+
+import argparse
+import re
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Convert constraint log to script to verify using Z3.')
+    parser.add_argument('log_file', metavar='log', type=str,
+                        help='constraint-system log file')
+    args = parser.parse_args()
+
+    content = ''
+    with open(args.log_file, 'rt') as f:
+        content = f.read()
+
+    groups = content.split('---')
+    var_re = re.compile('x\d+')
+
+    print('from z3 import *')
+    for group in groups:
+        constraints = [g.strip() for g in group.split('\n') if g.strip() != '']
+        variables = set()
+        for c in constraints[:-1]:
+            for m in var_re.finditer(c):
+                variables.add(m.group())
+        if len(variables) == 0:
+            continue
+        for v in variables:
+            print('{} = Int("{}")'.format(v, v))
+        print('s = Solver()')
+        for c in constraints[:-1]:
+            print('s.add({})'.format(c))
+        expected = constraints[-1].strip()
+        print('assert(s.check() == {})'.format(expected))
+    print('print("all checks passed")')
+
+
+if __name__ == '__main__':
+    main()