Restructure the parser to support nested name scopes. This allows for regions at...
authorRiver Riddle <riverriddle@google.com>
Mon, 3 Jun 2019 16:43:22 +0000 (09:43 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 4 Jun 2019 02:26:20 +0000 (19:26 -0700)
The following is now valid IR:
  foo.op ... {
    %val = ...
  }, {
    %val = ...
  }

PiperOrigin-RevId: 251249875

mlir/lib/Parser/Parser.cpp
mlir/test/IR/invalid.mlir
mlir/test/IR/parser.mlir

index 3f4bf04..87f94f7 100644 (file)
@@ -35,6 +35,7 @@
 #include "mlir/Transforms/Utils.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringSet.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/PrettyStackTrace.h"
 #include "llvm/Support/SMLoc.h"
@@ -2309,8 +2310,6 @@ public:
 
   ~FunctionParser();
 
-  ParseResult parseFunctionBody(bool hadNamedArguments);
-
   /// Parse a single operation successor and it's operand list.
   ParseResult parseSuccessorAndUseList(Block *&dest,
                                        SmallVectorImpl<Value *> &operands);
@@ -2360,11 +2359,16 @@ public:
   ParseResult
   parseOptionalSSAUseAndTypeList(SmallVectorImpl<ValueTy *> &results);
 
+  /// Push a new SSA name scope to the parser.
+  void pushSSANameScope();
+
+  /// Pop the last SSA name scope from the parser.
+  ParseResult popSSANameScope();
+
   // Block references.
 
-  ParseResult
-  parseOperationRegion(Region &region,
-                       ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments);
+  ParseResult parseRegion(Region &region,
+                          ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments);
   ParseResult parseRegionBody(Region &region);
   ParseResult parseBlock(Block *&block);
   ParseResult parseBlockBody(Block *block);
@@ -2411,15 +2415,37 @@ public:
 private:
   Function *function;
 
+  /// Returns the info for a block at the current scope for the given name.
+  std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) {
+    return blocksByName.back()[name];
+  }
+
+  /// Insert a new forward reference to the given block.
+  void insertForwardRef(Block *block, SMLoc loc) {
+    forwardRef.back().try_emplace(block, loc);
+  }
+
+  /// Erase any forward reference to the given block.
+  bool eraseForwardRef(Block *block) { return forwardRef.back().erase(block); }
+
+  /// Record that a definition was added at the current scope.
+  void recordDefinition(StringRef def) {
+    definitionsPerScope.back().insert(def);
+  }
+
   // This keeps track of the block names as well as the location of the first
-  // reference, used to diagnose invalid block references and memoize them.
-  llvm::StringMap<std::pair<Block *, SMLoc>> blocksByName;
-  DenseMap<Block *, SMLoc> forwardRef;
+  // reference for each nested name scope. This is used to diagnose invalid
+  // block references and memoize them.
+  SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName;
+  SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef;
 
-  /// This keeps track of all of the SSA values we are tracking, indexed by
-  /// their name.  This has one entry per result number.
+  /// This keeps track of all of the SSA values we are tracking for each name
+  /// scope, indexed by their name. This has one entry per result number.
   llvm::StringMap<SmallVector<std::pair<Value *, SMLoc>, 1>> values;
 
+  /// This keeps track of all of the values defined by a specific name scope.
+  SmallVector<llvm::StringSet<>, 2> definitionsPerScope;
+
   /// These are all of the placeholders we've made along with the location of
   /// their first reference, to allow checking for use of undefined values.
   DenseMap<Value *, SMLoc> forwardReferencePlaceholders;
@@ -2433,37 +2459,20 @@ private:
 };
 } // end anonymous namespace
 
-ParseResult FunctionParser::parseFunctionBody(bool hadNamedArguments) {
-  auto braceLoc = getToken().getLoc();
-  if (parseToken(Token::l_brace, "expected '{' in function"))
-    return failure();
-
-  // Make sure we have at least one block.
-  if (getToken().is(Token::r_brace))
-    return emitError("function must have a body");
-
-  // If we had named arguments, then we don't allow a block name.
-  if (hadNamedArguments) {
-    if (getToken().is(Token::caret_identifier))
-      return emitError("invalid block name in function with named arguments");
-  }
-
-  // The first block is already created and should be filled in.
-  auto firstBlock = &function->front();
-
-  // Parse the first block.
-  if (parseBlock(firstBlock))
-    return failure();
+void FunctionParser::pushSSANameScope() {
+  blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>());
+  forwardRef.push_back(DenseMap<Block *, SMLoc>());
+  definitionsPerScope.push_back({});
+}
 
-  // Parse the remaining list of blocks.
-  if (parseRegionBody(function->getBody()))
-    return failure();
+ParseResult FunctionParser::popSSANameScope() {
+  auto forwardRefInCurrentScope = forwardRef.pop_back_val();
 
   // Verify that all referenced blocks were defined.
-  if (!forwardRef.empty()) {
+  if (!forwardRefInCurrentScope.empty()) {
     SmallVector<std::pair<const char *, Block *>, 4> errors;
     // Iteration over the map isn't deterministic, so sort by source location.
-    for (auto entry : forwardRef) {
+    for (auto entry : forwardRefInCurrentScope) {
       errors.push_back({entry.second.getPointer(), entry.first});
       cleanupInvalidBlocks(entry.first);
     }
@@ -2476,14 +2485,19 @@ ParseResult FunctionParser::parseFunctionBody(bool hadNamedArguments) {
     return failure();
   }
 
-  return finalizeFunction(braceLoc);
+  // Drop any values defined in this scope from the value map.
+  for (auto &def : definitionsPerScope.pop_back_val())
+    values.erase(def.getKey());
+  blocksByName.pop_back();
+
+  return success();
 }
 
 /// Block list.
 ///
 ///   block-list ::= '{' block-list-body
 ///
-ParseResult FunctionParser::parseOperationRegion(
+ParseResult FunctionParser::parseRegion(
     Region &region,
     ArrayRef<std::pair<FunctionParser::SSAUseInfo, Type>> entryArguments) {
   // Parse the '{'.
@@ -2493,18 +2507,27 @@ ParseResult FunctionParser::parseOperationRegion(
   // Check for an empty region.
   if (entryArguments.empty() && consumeIf(Token::r_brace))
     return success();
-  Block *currentBlock = builder.getInsertionBlock();
+  auto currentPt = builder.saveInsertionPoint();
+
+  // Push a new named value scope.
+  pushSSANameScope();
 
   // Parse the first block directly to allow for it to be unnamed.
   Block *block = new Block();
 
   // Add arguments to the entry block.
-  for (auto &placeholderArgPair : entryArguments)
-    if (addDefinition(placeholderArgPair.first,
-                      block->addArgument(placeholderArgPair.second))) {
-      delete block;
-      return failure();
-    }
+  if (!entryArguments.empty()) {
+    for (auto &placeholderArgPair : entryArguments)
+      if (addDefinition(placeholderArgPair.first,
+                        block->addArgument(placeholderArgPair.second))) {
+        delete block;
+        return failure();
+      }
+
+    // If we had named arguments, then don't allow a block name.
+    if (getToken().is(Token::caret_identifier))
+      return emitError("invalid block name in region with named arguments");
+  }
 
   if (parseBlock(block)) {
     delete block;
@@ -2523,8 +2546,12 @@ ParseResult FunctionParser::parseOperationRegion(
   if (parseRegionBody(region))
     return failure();
 
-  // Reset insertion point to the current block.
-  builder.setInsertionPointToEnd(currentBlock);
+  // Pop the SSA value scope.
+  if (popSSANameScope())
+    return failure();
+
+  // Reset the original insertion point.
+  builder.restoreInsertionPoint(currentPt);
   return success();
 }
 
@@ -2709,8 +2736,9 @@ ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, Value *value) {
     forwardReferencePlaceholders.erase(existing);
   }
 
-  entries[useInfo.number].first = value;
-  entries[useInfo.number].second = useInfo.loc;
+  /// Record this definition for the current scope.
+  entries[useInfo.number] = {value, useInfo.loc};
+  recordDefinition(useInfo.name);
   return success();
 }
 
@@ -2761,7 +2789,6 @@ FunctionParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) {
 template <typename ResultType>
 ResultType FunctionParser::parseSSADefOrUseAndType(
     const std::function<ResultType(SSAUseInfo, Type)> &action) {
-
   SSAUseInfo useInfo;
   if (parseSSAUse(useInfo) ||
       parseToken(Token::colon, "expected ':' and type for SSA operand"))
@@ -2815,11 +2842,10 @@ ParseResult FunctionParser::parseOptionalSSAUseAndTypeList(
 /// exist.  The location specified is the point of use, which allows
 /// us to diagnose references to blocks that are not defined precisely.
 Block *FunctionParser::getBlockNamed(StringRef name, SMLoc loc) {
-  auto &blockAndLoc = blocksByName[name];
+  auto &blockAndLoc = getBlockInfoByName(name);
   if (!blockAndLoc.first) {
-    blockAndLoc.first = new Block();
-    forwardRef[blockAndLoc.first] = loc;
-    blockAndLoc.second = loc;
+    blockAndLoc = {new Block(), loc};
+    insertForwardRef(blockAndLoc.first, loc);
   }
 
   return blockAndLoc.first;
@@ -2829,7 +2855,7 @@ Block *FunctionParser::getBlockNamed(StringRef name, SMLoc loc) {
 /// the case of redefinition.
 Block *FunctionParser::defineBlockNamed(StringRef name, SMLoc loc,
                                         Block *existing) {
-  auto &blockAndLoc = blocksByName[name];
+  auto &blockAndLoc = getBlockInfoByName(name);
   if (!blockAndLoc.first) {
     // If the caller provided a block, use it.  Otherwise create a new one.
     if (!existing)
@@ -2842,7 +2868,7 @@ Block *FunctionParser::defineBlockNamed(StringRef name, SMLoc loc,
   // Forward declarations are removed once defined, so if we are defining a
   // existing block and it is not a forward declaration, then it is a
   // redeclaration.
-  if (!forwardRef.erase(blockAndLoc.first))
+  if (!eraseForwardRef(blockAndLoc.first))
     return nullptr;
   return blockAndLoc.first;
 }
@@ -3090,8 +3116,8 @@ Operation *FunctionParser::parseGenericOperation() {
     do {
       // Create temporary regions with function as parent.
       result.regions.emplace_back(new Region(function));
-      if (parseOperationRegion(*result.regions.back(),
-                               /*entryArguments*/ {}))
+      if (parseRegion(*result.regions.back(),
+                      /*entryArguments*/ {}))
         return nullptr;
     } while (consumeIf(Token::comma));
     if (parseToken(Token::r_paren, "expected ')' to end region list"))
@@ -3382,7 +3408,7 @@ public:
       parsedRegionEntryArgumentPlaceholders.emplace_back(value);
     }
 
-    return parser.parseOperationRegion(region, regionArguments);
+    return parser.parseRegion(region, regionArguments);
   }
 
   /// Parse a region argument.  Region arguments define new values, so this also
@@ -3506,10 +3532,11 @@ private:
   // Functions.
   ParseResult
   parseArgumentList(SmallVectorImpl<Type> &argTypes,
-                    SmallVectorImpl<StringRef> &argNames,
+                    SmallVectorImpl<std::pair<SMLoc, StringRef>> &argNames,
                     SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs);
   ParseResult parseFunctionSignature(
-      StringRef &name, FunctionType &type, SmallVectorImpl<StringRef> &argNames,
+      StringRef &name, FunctionType &type,
+      SmallVectorImpl<std::pair<SMLoc, StringRef>> &argNames,
       SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs);
   ParseResult parseFunc();
 };
@@ -3589,7 +3616,8 @@ ParseResult ModuleParser::parseTypeAliasDef() {
 ///                     | /*empty*/
 ///
 ParseResult ModuleParser::parseArgumentList(
-    SmallVectorImpl<Type> &argTypes, SmallVectorImpl<StringRef> &argNames,
+    SmallVectorImpl<Type> &argTypes,
+    SmallVectorImpl<std::pair<SMLoc, StringRef>> &argNames,
     SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs) {
   consumeToken(Token::l_paren);
 
@@ -3605,7 +3633,7 @@ ParseResult ModuleParser::parseArgumentList(
       if (argNames.empty() && !argTypes.empty())
         return emitError(loc, "expected type instead of SSA identifier");
 
-      argNames.push_back(name);
+      argNames.emplace_back(loc, name);
 
       if (parseToken(Token::colon, "expected ':'"))
         return failure();
@@ -3641,7 +3669,8 @@ ParseResult ModuleParser::parseArgumentList(
 ///      function-id `(` argument-list `)` (`->` type-list)?
 ///
 ParseResult ModuleParser::parseFunctionSignature(
-    StringRef &name, FunctionType &type, SmallVectorImpl<StringRef> &argNames,
+    StringRef &name, FunctionType &type,
+    SmallVectorImpl<std::pair<SMLoc, StringRef>> &argNames,
     SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs) {
   if (getToken().isNot(Token::at_identifier))
     return emitError("expected a function identifier like '@foo'");
@@ -3678,7 +3707,7 @@ ParseResult ModuleParser::parseFunc() {
 
   StringRef name;
   FunctionType type;
-  SmallVector<StringRef, 4> argNames;
+  SmallVector<std::pair<SMLoc, StringRef>, 4> argNames;
   SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
 
   auto loc = getToken().getLoc();
@@ -3712,24 +3741,26 @@ ParseResult ModuleParser::parseFunc() {
   // External functions have no body.
   if (getToken().isNot(Token::l_brace))
     return success();
+  auto braceLoc = getToken().getLoc();
 
-  // Create the parser.
-  auto parser = FunctionParser(getState(), function);
-
-  bool hadNamedArguments = !argNames.empty();
+  // Prepare the named function arguments.
+  SmallVector<std::pair<FunctionParser::SSAUseInfo, Type>, 4> entryArgs;
+  for (unsigned i = 0, e = argNames.size(); i != e; ++i) {
+    entryArgs.emplace_back(
+        FunctionParser::SSAUseInfo{argNames[i].second, 0, argNames[i].first},
+        type.getInput(i));
+  }
 
-  // Add the entry block and argument list.
-  function->addEntryBlock();
+  // Parse the function body.
+  auto parser = FunctionParser(getState(), function);
+  if (parser.parseRegion(function->getBody(), entryArgs))
+    return failure();
 
-  // Add definitions of the function arguments.
-  if (hadNamedArguments) {
-    for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
-      if (parser.addDefinition({argNames[i], 0, loc}, function->getArgument(i)))
-        return failure();
-    }
-  }
+  // Verify that a valid function body was parsed.
+  if (function->empty())
+    return emitError(braceLoc, "function must have a body");
 
-  return parser.parseFunctionBody(hadNamedArguments);
+  return parser.finalizeFunction(braceLoc);
 }
 
 /// This is the top-level module parser.
index 1e7fc3d..a7e7035 100644 (file)
@@ -143,8 +143,8 @@ func @block_first_has_predecessor() {
 
 // -----
 
-func @empty() {
-} // expected-error {{function must have a body}}
+func @empty() { // expected-error {{function must have a body}}
+}
 
 // -----
 
@@ -344,7 +344,7 @@ func @argError() {
 // -----
 
 func @bbargMismatch(i32, f32) {
-// expected-error @+1 {{argument and block argument type mismatch}}
+// expected-error @-1 {{first block of function must have 2 arguments to match function signature}}
 ^bb42(%0: f32):
   return
 }
@@ -429,10 +429,10 @@ func @duplicate_induction_var() {
 
 // -----
 
-func @dominance_failure() {
+func @name_scope_failure() {
   affine.for %i = 1 to 10 {
   }
-  "xxx"(%i) : (index)->()   // expected-error {{operand #0 does not dominate this use}}
+  "xxx"(%i) : (index)->()   // expected-error {{use of undeclared SSA value name}}
   return
 }
 
@@ -691,7 +691,7 @@ func @elementsattr_malformed_opaque3() -> () {
 // -----
 
 func @redundant_signature(%a : i32) -> () {
-^bb0(%b : i32):  // expected-error {{invalid block name in function with named arguments}}
+^bb0(%b : i32):  // expected-error {{invalid block name in region with named arguments}}
   return
 }
 
@@ -1053,3 +1053,17 @@ func @bad_complex(complex<i32)
 
 // expected-error @+1 {{attribute names with a '.' are reserved for dialect-defined names}}
 #foo.attr = i32
+
+// -----
+
+func @invalid_region_dominance() {
+  "foo.region"() ({
+    // expected-error @+1 {{operand #0 does not dominate this use}}
+    "foo.use" (%def) : (i32) -> ()
+    "foo.yield" () : () -> ()
+  }, {
+    // expected-note @+1 {{operand defined here}}
+    %def = "foo.def" () : () -> i32
+  }) : () -> ()
+  return
+}
index 801eb00..5ceadd9 100644 (file)
@@ -758,18 +758,6 @@ func @sparsevectorattr() -> () {
   return
 }
 
-// CHECK-LABEL: func @loops_with_blockids() {
-func @loops_with_blockids() {
-^block0:
-  affine.for %i = 1 to 100 step 2 {
-  ^block1:
-    affine.for %j = 1 to 200 {
-    ^block2:
-    }
-  }
-  return
-}
-
 // CHECK-LABEL: func @unknown_dialect_type() -> !bar<""> {
 func @unknown_dialect_type() -> !bar<""> {
   // Unregistered dialect 'bar'.
@@ -930,3 +918,18 @@ func @none_type() {
   %none_val = "foo.unknown_op"() : () -> none
   return
 }
+
+// CHECK-LABEL: func @scoped_names
+func @scoped_names() {
+  // CHECK-NEXT: "foo.region_op"
+  "foo.region_op"() ({
+    // CHECK-NEXT: "foo.unknown_op"
+    %scoped_name = "foo.unknown_op"() : () -> none
+    "foo.terminator"() : () -> ()
+  }, {
+    // CHECK: "foo.unknown_op"
+    %scoped_name = "foo.unknown_op"() : () -> none
+    "foo.terminator"() : () -> ()
+  }) : () -> ()
+  return
+}