[ODS] Support numRegions in Op definition
authorLei Zhang <antiagainst@google.com>
Tue, 28 May 2019 15:03:46 +0000 (08:03 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:05:31 +0000 (20:05 -0700)
--

PiperOrigin-RevId: 250282024

mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/Operator.h
mlir/lib/TableGen/Operator.cpp
mlir/test/IR/region.mlir [new file with mode: 0644]
mlir/test/TestDialect/TestOps.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index 8ef6e0b..edbd273 100644 (file)
@@ -916,6 +916,10 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
   // The list of results of the op. Default to 0 results.
   dag results = (outs);
 
+  // How many regions this op has.
+  // TODO(b/133479568): Enhance to support advanced region usage cases
+  int numRegions = 0;
+
   // Attribute getters can be added to the op by adding an Attr member
   // with the name and type of the attribute. E.g., adding int attribute
   // with name "value" and type "i32":
index 3f52847..77b3a9f 100644 (file)
@@ -129,6 +129,9 @@ public:
   // requiring the raw MLIR trait here.
   bool hasTrait(llvm::StringRef trait) const;
 
+  // Returns the number of regions.
+  int getNumRegions() const;
+
   // Trait.
   using const_trait_iterator = const OpTrait *;
   const_trait_iterator trait_begin() const;
@@ -174,6 +177,9 @@ private:
   // The traits of the op.
   SmallVector<OpTrait, 4> traits;
 
+  // The number of regions of this op.
+  int numRegions = 0;
+
   // The number of native attributes stored in the leading positions of
   // `attributes`.
   int numNativeAttributes;
index 2223974..d27db0d 100644 (file)
@@ -146,6 +146,8 @@ bool tblgen::Operator::hasTrait(StringRef trait) const {
   return false;
 }
 
+int tblgen::Operator::getNumRegions() const { return numRegions; }
+
 auto tblgen::Operator::trait_begin() const -> const_trait_iterator {
   return traits.begin();
 }
@@ -265,6 +267,11 @@ void tblgen::Operator::populateOpStructure() {
   traits.reserve(traitListInit->size());
   for (auto traitInit : *traitListInit)
     traits.push_back(OpTrait::create(traitInit));
+
+  // Handle regions
+  numRegions = def.getValueAsInt("numRegions");
+  if (numRegions < 0)
+    PrintFatalError(def.getLoc(), "numRegions cannot be negative");
 }
 
 ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); }
diff --git a/mlir/test/IR/region.mlir b/mlir/test/IR/region.mlir
new file mode 100644 (file)
index 0000000..702e56d
--- /dev/null
@@ -0,0 +1,32 @@
+// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
+
+func @correct_number_of_regions() {
+    // CHECK: test.two_region_op
+    "test.two_region_op"()(
+      {"work"() : () -> ()},
+      {"work"() : () -> ()}
+    ) : () -> ()
+    return
+}
+
+// -----
+
+func @missingk_regions() {
+    // expected-error@+1 {{op has incorrect number of regions: expected 2 but found 1}}
+    "test.two_region_op"()(
+      {"work"() : () -> ()}
+    ) : () -> ()
+    return
+}
+
+// -----
+
+func @extra_regions() {
+    // expected-error@+1 {{op has incorrect number of regions: expected 2 but found 3}}
+    "test.two_region_op"()(
+      {"work"() : () -> ()},
+      {"work"() : () -> ()},
+      {"work"() : () -> ()}
+    ) : () -> ()
+    return
+}
index 3c0ade3..915318d 100644 (file)
@@ -113,4 +113,12 @@ def : Pat<(OpD $input), (OpF $input), [], (addBenefit 10)>;
 def : Pat<(OpG $input), (OpB $input, ConstantAttr<I32Attr, "20">:$attr)>;
 def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr<I32Attr, "34">:$attr)>;
 
+//===----------------------------------------------------------------------===//
+// Test op regions
+//===----------------------------------------------------------------------===//
+
+def TwoRegionOp : TEST_Op<"two_region_op", []> {
+  let numRegions = 2;
+}
+
 #endif // TEST_OPS
index 3d5a3b9..ca8b27a 100644 (file)
@@ -742,6 +742,12 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
       }
     }
   }
+
+  // Create the correct number of regions
+  if (int numRegions = op.getNumRegions()) {
+    for (int i = 0; i < numRegions; ++i)
+      m.body() << "  (void)" << builderOpState << "->addRegion();\n";
+  }
 }
 
 void OpEmitter::genBuilder() {
@@ -820,6 +826,12 @@ void OpEmitter::genBuilder() {
        << "    " << builderOpState
        << "->addAttribute(pair.first, pair.second);\n";
 
+  // Create the correct number of regions
+  if (int numRegions = op.getNumRegions()) {
+    for (int i = 0; i < numRegions; ++i)
+      m.body() << "  (void)" << builderOpState << "->addRegion();\n";
+  }
+
   // 3. Deduced result types
 
   bool useOperandType = op.hasTrait("SameOperandsAndResultType");
@@ -883,9 +895,6 @@ void OpEmitter::genVerifier() {
   auto valueInit = def.getValueInit("verifier");
   CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
   bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
-  if (!hasCustomVerify && op.getNumArgs() == 0 && op.getNumResults() == 0 &&
-      op.getNumPredOpTraits() == 0)
-    return;
 
   auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
   auto &body = method.body();
@@ -972,6 +981,13 @@ void OpEmitter::genVerifier() {
     }
   }
 
+  // Verify this op has the correct number of regions
+  body << formatv(
+      "  if (this->getOperation()->getNumRegions() != {0}) \n    return "
+      "emitOpError(\"has incorrect number of regions: expected {0} but found "
+      "\") << this->getOperation()->getNumRegions();\n",
+      op.getNumRegions());
+
   if (hasCustomVerify)
     body << codeInit->getValue() << "\n";
   else