[ODS] Support region names and constraints
authorLei Zhang <antiagainst@google.com>
Thu, 30 May 2019 23:50:16 +0000 (16:50 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:11:42 +0000 (20:11 -0700)
    Similar to arguments and results, now we require region definition in ops to
    be specified as a DAG expression with the 'region' operator. This way we can
    specify the constraints for each region and optionally give the region a name.

    Two kinds of region constraints are added, one allowing any region, and the
    other requires a certain number of blocks.

--

PiperOrigin-RevId: 250790211

mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/SPIRV/SPIRVStructureOps.td
mlir/include/mlir/TableGen/Constraint.h
mlir/include/mlir/TableGen/Operator.h
mlir/include/mlir/TableGen/Region.h [new file with mode: 0644]
mlir/lib/TableGen/Constraint.cpp
mlir/lib/TableGen/Operator.cpp
mlir/test/IR/region.mlir
mlir/test/TestDialect/TestOps.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index ce1a87b..f046fd2 100644 (file)
@@ -135,6 +135,8 @@ class Concat<string pre, Pred child, string suf> :
 // Constraint definitions
 //===----------------------------------------------------------------------===//
 
+// TODO(b/130064155): Merge Constraints into Pred.
+
 // Base class for named constraints.
 //
 // An op's operands/attributes/results can have various requirements, e.g.,
@@ -170,6 +172,10 @@ class TypeConstraint<Pred predicate, string description = ""> :
 class AttrConstraint<Pred predicate, string description = ""> :
     Constraint<predicate, description>;
 
+// Subclass for constraints on a region.
+class RegionConstraint<Pred predicate, string description = ""> :
+    Constraint<predicate, description>;
+
 // How to use these constraint categories:
 //
 // * Use TypeConstraint to specify
@@ -796,6 +802,21 @@ def IsNullAttr : AttrConstraint<
     CPred<"!$_self">, "empty attribute (for optional attributes)">;
 
 //===----------------------------------------------------------------------===//
+// Region definitions
+//===----------------------------------------------------------------------===//
+
+class Region<Pred condition, string descr = ""> :
+    RegionConstraint<condition, descr>;
+
+// Any region.
+def AnyRegion : Region<CPred<"true">, "any region">;
+
+// A region with the given number of blocks.
+class SizedRegion<int numBlocks> : Region<
+  CPred<"$_self.getBlocks().size() == " # numBlocks>,
+  "region with " # numBlocks # " blocks">;
+
+//===----------------------------------------------------------------------===//
 // OpTrait definitions
 //===----------------------------------------------------------------------===//
 
@@ -869,6 +890,9 @@ def ins;
 // Marker used to identify the result list for an op.
 def outs;
 
+// Marker used to identify the region list for an op.
+def region;
+
 // Class for defining a custom builder.
 //
 // TableGen generates several generic builders for each op by default (see
@@ -916,9 +940,8 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
   // The list of results of the op. Default to 0 results.
   dag results = (outs);
 
-  // How many regions this op has.
-  // TODO(b/133479568): Enhance to support advanced region usage cases
-  int numRegions = 0;
+  // The list of regions of the op. Default to 0 regions.
+  dag regions = (region);
 
   // Attribute getters can be added to the op by adding an Attr member
   // with the name and type of the attribute. E.g., adding int attribute
index 7fef351..bcb485c 100644 (file)
@@ -72,7 +72,7 @@ def SPV_ModuleOp : SPV_Op<"module", []> {
 
   let results = (outs);
 
-  let numRegions = 1;
+  let regions = (region AnyRegion:$body);
 
   // Custom parser and printer implemented by static functions in SPVOps.cpp
   let parser = [{ return parseModule(parser, result); }];
index f8b12d9..bcf207e 100644 (file)
@@ -57,7 +57,7 @@ public:
   StringRef getDescription() const;
 
   // Constraint kind
-  enum Kind { CK_Type, CK_Attr, CK_Uncategorized };
+  enum Kind { CK_Attr, CK_Region, CK_Type, CK_Uncategorized };
 
   Kind getKind() const { return kind; }
 
index 77b3a9f..de2818e 100644 (file)
@@ -27,6 +27,7 @@
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/Dialect.h"
 #include "mlir/TableGen/OpTrait.h"
+#include "mlir/TableGen/Region.h"
 #include "mlir/TableGen/Type.h"
 #include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/SmallVector.h"
@@ -129,8 +130,15 @@ public:
   // requiring the raw MLIR trait here.
   bool hasTrait(llvm::StringRef trait) const;
 
+  using const_region_iterator = const NamedRegion *;
+  const_region_iterator region_begin() const;
+  const_region_iterator region_end() const;
+  llvm::iterator_range<const_region_iterator> getRegions() const;
+
   // Returns the number of regions.
-  int getNumRegions() const;
+  unsigned getNumRegions() const;
+  // Returns the `index`-th region.
+  const NamedRegion &getRegion(unsigned index) const;
 
   // Trait.
   using const_trait_iterator = const OpTrait *;
@@ -177,8 +185,8 @@ private:
   // The traits of the op.
   SmallVector<OpTrait, 4> traits;
 
-  // The number of regions of this op.
-  int numRegions = 0;
+  // The regions of this op.
+  SmallVector<NamedRegion, 1> regions;
 
   // The number of native attributes stored in the leading positions of
   // `attributes`.
diff --git a/mlir/include/mlir/TableGen/Region.h b/mlir/include/mlir/TableGen/Region.h
new file mode 100644 (file)
index 0000000..21dffe6
--- /dev/null
@@ -0,0 +1,45 @@
+//===- TGRegion.h - TableGen region definitions -----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_TABLEGEN_REGION_H_
+#define MLIR_TABLEGEN_REGION_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Constraint.h"
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class providing helper methods for accessing Region defined in
+// TableGen.
+class Region : public Constraint {
+public:
+  using Constraint::Constraint;
+
+  static bool classof(const Constraint *c) { return c->getKind() == CK_Region; }
+};
+
+// A struct bundling a region's constraint and its name.
+struct NamedRegion {
+  StringRef name;
+  Region constraint;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_REGION_H_
index 2656f4c..96f49bf 100644 (file)
@@ -30,6 +30,8 @@ Constraint::Constraint(const llvm::Record *record)
     kind = CK_Type;
   } else if (record->isSubClassOf("AttrConstraint")) {
     kind = CK_Attr;
+  } else if (record->isSubClassOf("RegionConstraint")) {
+    kind = CK_Region;
   } else {
     assert(record->isSubClassOf("Constraint"));
   }
index d27db0d..cd3537d 100644 (file)
@@ -146,7 +146,24 @@ bool tblgen::Operator::hasTrait(StringRef trait) const {
   return false;
 }
 
-int tblgen::Operator::getNumRegions() const { return numRegions; }
+tblgen::Operator::const_region_iterator tblgen::Operator::region_begin() const {
+  return regions.begin();
+}
+
+tblgen::Operator::const_region_iterator tblgen::Operator::region_end() const {
+  return regions.end();
+}
+
+llvm::iterator_range<tblgen::Operator::const_region_iterator>
+tblgen::Operator::getRegions() const {
+  return {region_begin(), region_end()};
+}
+
+unsigned tblgen::Operator::getNumRegions() const { return regions.size(); }
+
+const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const {
+  return regions[index];
+}
 
 auto tblgen::Operator::trait_begin() const -> const_trait_iterator {
   return traits.begin();
@@ -269,9 +286,21 @@ void tblgen::Operator::populateOpStructure() {
     traits.push_back(OpTrait::create(traitInit));
 
   // Handle regions
-  numRegions = def.getValueAsInt("numRegions");
-  if (numRegions < 0)
-    PrintFatalError(def.getLoc(), "numRegions cannot be negative");
+  auto *regionsDag = def.getValueAsDag("regions");
+  auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
+  if (!regionsOp || regionsOp->getDef()->getName() != "region") {
+    PrintFatalError(def.getLoc(), "'regions' must have 'region' directive");
+  }
+
+  for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
+    auto name = regionsDag->getArgNameStr(i);
+    auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
+    if (!regionInit) {
+      PrintFatalError(def.getLoc(),
+                      Twine("undefined kind for region #") + Twine(i));
+    }
+    regions.push_back({name, Region(regionInit->getDef())});
+  }
 }
 
 ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); }
index 702e56d..03b366c 100644 (file)
@@ -1,5 +1,9 @@
 // RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// Test the number of regions
+//===----------------------------------------------------------------------===//
+
 func @correct_number_of_regions() {
     // CHECK: test.two_region_op
     "test.two_region_op"()(
@@ -11,7 +15,7 @@ func @correct_number_of_regions() {
 
 // -----
 
-func @missingk_regions() {
+func @missing_regions() {
     // expected-error@+1 {{op has incorrect number of regions: expected 2 but found 1}}
     "test.two_region_op"()(
       {"work"() : () -> ()}
@@ -30,3 +34,42 @@ func @extra_regions() {
     ) : () -> ()
     return
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test SizedRegion
+//===----------------------------------------------------------------------===//
+
+func @unnamed_region_has_wrong_number_of_blocks() {
+    // expected-error@+1 {{region #1 failed to verify constraint: region with 1 blocks}}
+    "test.sized_region_op"() (
+    {
+        "work"() : () -> ()
+        br ^next1
+      ^next1:
+        "work"() : () -> ()
+    },
+    {
+        "work"() : () -> ()
+        br ^next2
+      ^next2:
+        "work"() : () -> ()
+    }) : () -> ()
+    return
+}
+
+// -----
+
+// Test region name in error message
+func @named_region_has_wrong_number_of_blocks() {
+    // expected-error@+1 {{region #0 ('my_region') failed to verify constraint: region with 2 blocks}}
+    "test.sized_region_op"() (
+    {
+        "work"() : () -> ()
+    },
+    {
+        "work"() : () -> ()
+    }) : () -> ()
+    return
+}
index 5ffbcbc..814bc72 100644 (file)
@@ -129,7 +129,11 @@ def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr<I32Attr, "34">:$attr)>;
 //===----------------------------------------------------------------------===//
 
 def TwoRegionOp : TEST_Op<"two_region_op", []> {
-  let numRegions = 2;
+  let regions = (region AnyRegion, AnyRegion);
+}
+
+def SizedRegionOp : TEST_Op<"sized_region_op", []> {
+  let regions = (region SizedRegion<2>:$my_region, SizedRegion<1>);
 }
 
 #endif // TEST_OPS
index 1dc9a95..a7b347e 100644 (file)
@@ -368,6 +368,10 @@ private:
   // Generates verify method for the operation.
   void genVerifier();
 
+  // Generates verify statements for regions in the operation.
+  // The generated code will be attached to `body`.
+  void genRegionVerifier(OpMethodBody &body);
+
   // Generates the traits used by the object.
   void genTraits();
 
@@ -388,12 +392,17 @@ private:
 
   // The C++ code builder for this op
   OpClass opClass;
+
+  // The format context for verification code generation.
+  FmtContext verifyCtx;
 };
 } // end anonymous namespace
 
 OpEmitter::OpEmitter(const Record &def)
     : def(def), op(def),
       opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
+  verifyCtx.withOp("(*this->getOperation())");
+
   genTraits();
   // Generate C++ code for various op methods. The order here determines the
   // methods in the generated file.
@@ -900,13 +909,11 @@ void OpEmitter::genVerifier() {
 
   auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
   auto &body = method.body();
-  FmtContext fctx;
-  fctx.withOp("(*this->getOperation())");
 
   // Populate substitutions for attributes and named operands and results.
   for (const auto &namedAttr : op.getAttributes())
-    fctx.addSubst(namedAttr.name,
-                  formatv("(&this->getAttr(\"{0}\"))", namedAttr.name));
+    verifyCtx.addSubst(namedAttr.name,
+                       formatv("(&this->getAttr(\"{0}\"))", namedAttr.name));
   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
     auto &value = op.getOperand(i);
     // Skip from from first variadic operands for now. Else getOperand index
@@ -914,8 +921,8 @@ void OpEmitter::genVerifier() {
     if (value.isVariadic())
       break;
     if (!value.name.empty())
-      fctx.addSubst(value.name,
-                    formatv("this->getOperation()->getOperand({0})", i));
+      verifyCtx.addSubst(value.name,
+                         formatv("this->getOperation()->getOperand({0})", i));
   }
   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
     auto &value = op.getResult(i);
@@ -924,8 +931,8 @@ void OpEmitter::genVerifier() {
     if (value.isVariadic())
       break;
     if (!value.name.empty())
-      fctx.addSubst(value.name,
-                    formatv("this->getOperation()->getResult({0})", i));
+      verifyCtx.addSubst(value.name,
+                         formatv("this->getOperation()->getResult({0})", i));
   }
 
   // Verify the attributes have the correct type.
@@ -955,11 +962,12 @@ void OpEmitter::genVerifier() {
 
     auto attrPred = attr.getPredicate();
     if (!attrPred.isNull()) {
-      body << tgfmt("    if (!($0)) return emitOpError(\"attribute '$1' "
-                    "failed to satisfy constraint: $2\");\n",
-                    /*ctx=*/nullptr,
-                    tgfmt(attrPred.getCondition(), &fctx.withSelf(varName)),
-                    attrName, attr.getDescription());
+      body << tgfmt(
+          "    if (!($0)) return emitOpError(\"attribute '$1' "
+          "failed to satisfy constraint: $2\");\n",
+          /*ctx=*/nullptr,
+          tgfmt(attrPred.getCondition(), &verifyCtx.withSelf(varName)),
+          attrName, attr.getDescription());
     }
 
     body << "  }\n";
@@ -977,10 +985,11 @@ void OpEmitter::genVerifier() {
     if (value.hasPredicate()) {
       auto description = value.constraint.getDescription();
       body << "  if (!("
-           << tgfmt(value.constraint.getConditionTemplate(),
-                    &fctx.withSelf("this->getOperation()->get" +
-                                   Twine(isOperand ? "Operand" : "Result") +
-                                   "(" + Twine(index) + ")->getType()"))
+           << tgfmt(
+                  value.constraint.getConditionTemplate(),
+                  &verifyCtx.withSelf("this->getOperation()->get" +
+                                      Twine(isOperand ? "Operand" : "Result") +
+                                      "(" + Twine(index) + ")->getType()"))
            << ")) {\n";
       body << "    return emitOpError(\"" << (isOperand ? "operand" : "result")
            << " #" << index
@@ -1000,19 +1009,14 @@ void OpEmitter::genVerifier() {
 
   for (auto &trait : op.getTraits()) {
     if (auto t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
-      body << tgfmt("  if (!($0))\n    return emitOpError(\""
-                    "failed to verify that $1\");\n",
-                    &fctx, tgfmt(t->getPredTemplate(), &fctx),
+      body << tgfmt("  if (!($0)) {\n    "
+                    "return emitOpError(\"failed to verify that $1\");\n  }\n",
+                    &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
                     t->getDescription());
     }
   }
 
-  // Verify this op has the correct number of regions
-  body << formatv(
-      "  if (this->getOperation()->getNumRegions() != {0}) \n    return "
-      "emitOpError(\"has incorrect number of regions: expected {0} but found "
-      "\") << this->getOperation()->getNumRegions();\n",
-      op.getNumRegions());
+  genRegionVerifier(body);
 
   if (hasCustomVerify)
     body << codeInit->getValue() << "\n";
@@ -1020,6 +1024,36 @@ void OpEmitter::genVerifier() {
     body << "  return mlir::success();\n";
 }
 
+void OpEmitter::genRegionVerifier(OpMethodBody &body) {
+  unsigned numRegions = op.getNumRegions();
+
+  // Verify this op has the correct number of regions
+  body << formatv(
+      "  if (this->getOperation()->getNumRegions() != {0}) {\n    "
+      "return emitOpError(\"has incorrect number of regions: expected {0} but "
+      "found \") << this->getOperation()->getNumRegions();\n  }\n",
+      numRegions);
+
+  for (unsigned i = 0; i < numRegions; ++i) {
+    const auto &region = op.getRegion(i);
+
+    std::string name = formatv("#{0}", i);
+    if (!region.name.empty()) {
+      name += formatv(" ('{0}')", region.name);
+    }
+
+    auto getRegion = formatv("this->getOperation()->getRegion({0})", i).str();
+    auto constraint = tgfmt(region.constraint.getConditionTemplate(),
+                            &verifyCtx.withSelf(getRegion))
+                          .str();
+
+    body << formatv("  if (!({0})) {\n    "
+                    "return emitOpError(\"region {1} failed to verify "
+                    "constraint: {2}\");\n  }\n",
+                    constraint, name, region.constraint.getDescription());
+  }
+}
+
 void OpEmitter::genTraits() {
   int numResults = op.getNumResults();
   int numVariadicResults = op.getNumVariadicResults();