[mlir] FlatAffineConstraint parsing for unit tests
authorChristian Ulmann <culmann@student.ethz.ch>
Sun, 14 Nov 2021 18:19:19 +0000 (23:49 +0530)
committerGroverkss <groverkss@gmail.com>
Sun, 14 Nov 2021 18:20:38 +0000 (23:50 +0530)
This patch adds functionality to parse FlatAffineConstraints from a
StringRef with the intention to be used for unit tests. This should
make the construction of FlatAffineConstraints easier for testing
purposes.

The patch contains an example usage of the functionality in a unit test that
uses FlatAffineConstraints.

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D113275

mlir/include/mlir/Analysis/AffineStructures.h
mlir/include/mlir/Parser.h
mlir/lib/Analysis/AffineStructures.cpp
mlir/lib/Analysis/CMakeLists.txt
mlir/lib/Parser/AffineParser.cpp
mlir/unittests/Analysis/AffineStructuresParserTest.cpp [new file with mode: 0644]
mlir/unittests/Analysis/AffineStructuresTest.cpp
mlir/unittests/Analysis/CMakeLists.txt

index 0c351c7..8d0c915 100644 (file)
@@ -1071,6 +1071,13 @@ LogicalResult getRelationFromMap(AffineMap &map, FlatAffineRelation &rel);
 LogicalResult getRelationFromMap(const AffineValueMap &map,
                                  FlatAffineRelation &rel);
 
+/// This parses a single IntegerSet to an MLIR context and transforms it to
+/// FlatAffineConstraints if it was valid. If not, a failure is returned. If the
+/// passed `str` has additional tokens that were not part of the IntegerSet, a
+/// failure is returned.
+FailureOr<FlatAffineConstraints> parseIntegerSetToFAC(llvm::StringRef,
+                                                      MLIRContext *context);
+
 } // end namespace mlir.
 
 #endif // MLIR_ANALYSIS_AFFINESTRUCTURES_H
index a534dfc..4236a40 100644 (file)
@@ -256,6 +256,14 @@ Type parseType(llvm::StringRef typeStr, MLIRContext *context);
 /// `typeStr`. The number of characters of `typeStr` parsed in the process is
 /// returned in `numRead`.
 Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t &numRead);
+
+/// This parses a single IntegerSet to an MLIR context if it was valid. If not,
+/// an error message is emitted through a new SourceMgrDiagnosticHandler
+/// constructed from a new SourceMgr with a single MemoryBuffer wrapping
+/// `str`. If the passed `str` has additional tokens that were not part of the
+/// IntegerSet, a failure is returned.
+IntegerSet parseIntegerSet(llvm::StringRef str, MLIRContext *context);
+
 } // end namespace mlir
 
 #endif // MLIR_PARSER_H
index 4e23b24..e942a38 100644 (file)
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/Parser.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/STLExtras.h"
@@ -3865,3 +3866,13 @@ LogicalResult mlir::getRelationFromMap(const AffineValueMap &map,
 
   return success();
 }
+
+FailureOr<FlatAffineConstraints>
+mlir::parseIntegerSetToFAC(llvm::StringRef str, MLIRContext *context) {
+  IntegerSet set = parseIntegerSet(str, context);
+
+  if (!set)
+    return failure();
+
+  return FlatAffineConstraints(set);
+}
index 9128ef9..e4f67ac 100644 (file)
@@ -66,6 +66,7 @@ add_mlir_library(MLIRLoopAnalysis
   MLIRCallInterfaces
   MLIRControlFlowInterfaces
   MLIRInferTypeOpInterface
+  MLIRParser
   MLIRPresburger
   MLIRSCF
   )
index 7089837..86f2c80 100644 (file)
 #include "Parser.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/IntegerSet.h"
+#include "llvm/Support/SourceMgr.h"
 
 using namespace mlir;
 using namespace mlir::detail;
+using llvm::MemoryBuffer;
 using llvm::SMLoc;
+using llvm::SourceMgr;
 
 namespace {
 
@@ -717,3 +720,26 @@ Parser::parseAffineExprOfSSAIds(AffineExpr &expr,
   return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement)
       .parseAffineExprOfSSAIds(expr);
 }
+
+IntegerSet mlir::parseIntegerSet(StringRef inputStr, MLIRContext *context) {
+  llvm::SourceMgr sourceMgr;
+  auto memBuffer = llvm::MemoryBuffer::getMemBuffer(
+      inputStr, /*BufferName=*/"<mlir_parser_buffer>",
+      /*RequiresNullTerminator=*/false);
+  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
+  SymbolState symbolState;
+  ParserState state(sourceMgr, context, symbolState, /*asmState=*/nullptr);
+  Parser parser(state);
+  SourceMgrDiagnosticHandler handler(sourceMgr, context);
+  IntegerSet set;
+  if (parser.parseIntegerSetReference(set))
+    return IntegerSet();
+
+  Token endTok = parser.getToken();
+  if (endTok.isNot(Token::eof)) {
+    parser.emitError(endTok.getLoc(), "encountered unexpected token");
+    return IntegerSet();
+  }
+
+  return set;
+}
diff --git a/mlir/unittests/Analysis/AffineStructuresParserTest.cpp b/mlir/unittests/Analysis/AffineStructuresParserTest.cpp
new file mode 100644 (file)
index 0000000..4cdba7f
--- /dev/null
@@ -0,0 +1,137 @@
+//===- AffineStructuresParserTest.cpp - FAC parsing unit tests --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains tests for parsing IntegerSets to FlatAffineConstraints.
+// The tests with invalid input check that the parser only accepts well-formed
+// IntegerSets. The tests with well-formed input compare the returned FACs to
+// manually constructed FACs with a PresburgerSet equality check.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/PresburgerSet.h"
+
+#include <gtest/gtest.h>
+
+namespace mlir {
+
+/// Construct a FlatAffineConstraints from a set of inequality, equality, and
+/// division onstraints.
+static FlatAffineConstraints makeFACFromConstraints(
+    unsigned dims, unsigned syms, ArrayRef<SmallVector<int64_t, 4>> ineqs,
+    ArrayRef<SmallVector<int64_t, 4>> eqs = {},
+    ArrayRef<std::pair<SmallVector<int64_t, 4>, int64_t>> divs = {}) {
+  FlatAffineConstraints fac(ineqs.size(), eqs.size(), dims + syms + 1, dims,
+                            syms, 0);
+  for (const auto &div : divs)
+    fac.addLocalFloorDiv(div.first, div.second);
+  for (const auto &eq : eqs)
+    fac.addEquality(eq);
+  for (const auto &ineq : ineqs)
+    fac.addInequality(ineq);
+  return fac;
+}
+
+TEST(ParseFACTest, InvalidInputTest) {
+  MLIRContext context;
+  FailureOr<FlatAffineConstraints> fac;
+
+  fac = parseIntegerSetToFAC("(x)", &context);
+  EXPECT_TRUE(failed(fac))
+      << "should not accept strings with no constraint list";
+
+  fac = parseIntegerSetToFAC("(x)[] : ())", &context);
+  EXPECT_TRUE(failed(fac))
+      << "should not accept strings that contain remaining characters";
+
+  fac = parseIntegerSetToFAC("(x)[] : (x - >= 0)", &context);
+  EXPECT_TRUE(failed(fac))
+      << "should not accept strings that contain incomplete constraints";
+
+  fac = parseIntegerSetToFAC("(x)[] : (y == 0)", &context);
+  EXPECT_TRUE(failed(fac))
+      << "should not accept strings that contain unkown identifiers";
+
+  fac = parseIntegerSetToFAC("(x, x) : (2 * x >= 0)", &context);
+  EXPECT_TRUE(failed(fac))
+      << "should not accept strings that contain repeated identifier names";
+
+  fac = parseIntegerSetToFAC("(x)[x] : (2 * x >= 0)", &context);
+  EXPECT_TRUE(failed(fac))
+      << "should not accept strings that contain repeated identifier names";
+
+  fac = parseIntegerSetToFAC("(x) : (2 * x + 9223372036854775808 >= 0)",
+                             &context);
+  EXPECT_TRUE(failed(fac)) << "should not accept strings with integer literals "
+                              "that do not fit into int64_t";
+}
+
+/// Parses and compares the `str` to the `ex`. The equality check is performed
+/// by using PresburgerSet::isEqual
+static bool parseAndCompare(StringRef str, FlatAffineConstraints ex,
+                            MLIRContext *context) {
+  FailureOr<FlatAffineConstraints> fac = parseIntegerSetToFAC(str, context);
+
+  EXPECT_TRUE(succeeded(fac));
+
+  return PresburgerSet(*fac).isEqual(PresburgerSet(ex));
+}
+
+TEST(ParseFACTest, ParseAndCompareTest) {
+  MLIRContext context;
+  // simple ineq
+  EXPECT_TRUE(parseAndCompare(
+      "(x)[] : (x >= 0)", makeFACFromConstraints(1, 0, {{1, 0}}), &context));
+
+  // simple eq
+  EXPECT_TRUE(parseAndCompare("(x)[] : (x == 0)",
+                              makeFACFromConstraints(1, 0, {}, {{1, 0}}),
+                              &context));
+
+  // multiple constraints
+  EXPECT_TRUE(parseAndCompare("(x)[] : (7 * x >= 0, -7 * x + 5 >= 0)",
+                              makeFACFromConstraints(1, 0, {{7, 0}, {-7, 5}}),
+                              &context));
+
+  // multiple dimensions
+  EXPECT_TRUE(parseAndCompare("(x,y,z)[] : (x + y - z >= 0)",
+                              makeFACFromConstraints(3, 0, {{1, 1, -1, 0}}),
+                              &context));
+
+  // dimensions and symbols
+  EXPECT_TRUE(parseAndCompare(
+      "(x,y,z)[a,b] : (x + y - z + 2 * a - 15 * b >= 0)",
+      makeFACFromConstraints(3, 2, {{1, 1, -1, 2, -15, 0}}), &context));
+
+  // only symbols
+  EXPECT_TRUE(parseAndCompare("()[a] : (2 * a - 4 == 0)",
+                              makeFACFromConstraints(0, 1, {}, {{2, -4}}),
+                              &context));
+
+  // simple floordiv
+  EXPECT_TRUE(parseAndCompare(
+      "(x, y) : (y - 3 * ((x + y - 13) floordiv 3) - 42 == 0)",
+      makeFACFromConstraints(2, 0, {}, {{0, 1, -3, -42}}, {{{1, 1, -13}, 3}}),
+      &context));
+
+  // multiple floordiv
+  EXPECT_TRUE(parseAndCompare(
+      "(x, y) : (y - x floordiv 3 - y floordiv 2 == 0)",
+      makeFACFromConstraints(2, 0, {}, {{0, 1, -1, -1, 0}},
+                             {{{1, 0, 0}, 3}, {{0, 1, 0, 0}, 2}}),
+      &context));
+
+  // nested floordiv
+  EXPECT_TRUE(parseAndCompare(
+      "(x, y) : (y - (x + y floordiv 2) floordiv 3 == 0)",
+      makeFACFromConstraints(2, 0, {}, {{0, 1, 0, -1, 0}},
+                             {{{0, 1, 0}, 2}, {{1, 0, 1, 0}, 3}}),
+      &context));
+}
+
+} // namespace mlir
index ef83b7c..b43e267 100644 (file)
@@ -98,11 +98,24 @@ static void checkPermutationsSample(bool hasSample, unsigned nDim,
   } while (std::next_permutation(perm.begin(), perm.end()));
 }
 
+/// Parses a FlatAffineConstraints from a StringRef. It is expected that the
+/// string represents a valid IntegerSet, otherwise it will violate a gtest
+/// assertion.
+static FlatAffineConstraints parseFAC(StringRef str, MLIRContext *context) {
+  FailureOr<FlatAffineConstraints> fac = parseIntegerSetToFAC(str, context);
+
+  EXPECT_TRUE(succeeded(fac));
+
+  return *fac;
+}
+
 TEST(FlatAffineConstraintsTest, FindSampleTest) {
   // Bounded sets with only inequalities.
 
+  MLIRContext context;
+
   // 0 <= 7x <= 5
-  checkSample(true, makeFACFromConstraints(1, {{7, 0}, {-7, 5}}, {}));
+  checkSample(true, parseFAC("(x) : (7 * x >= 0, -7 * x + 5 >= 0)", &context));
 
   // 1 <= 5x and 5x <= 4 (no solution).
   checkSample(false, makeFACFromConstraints(1, {{5, -1}, {-5, 4}}, {}));
index 0df0af8..05a492d 100644 (file)
@@ -1,4 +1,5 @@
 add_mlir_unittest(MLIRAnalysisTests
+  AffineStructuresParserTest.cpp
   AffineStructuresTest.cpp
   LinearTransformTest.cpp
   PresburgerSetTest.cpp