Validate the names of attribute, dialect, and functions during verification. This...
authorRiver Riddle <riverriddle@google.com>
Wed, 27 Feb 2019 00:43:12 +0000 (16:43 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 23:44:53 +0000 (16:44 -0700)
PiperOrigin-RevId: 235818842

mlir/g3doc/LangRef.md
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/Types.h
mlir/lib/Analysis/Verifier.cpp
mlir/lib/IR/Dialect.cpp
mlir/lib/IR/Types.cpp
mlir/lib/Parser/Lexer.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/IR/invalid.mlir

index e6008f20faf352f1c158de277ef2b869394457d6..07a422fdfe025d976d85154d49a40c8dbe1b4562 100644 (file)
@@ -173,7 +173,7 @@ Syntax:
 
 ``` {.ebnf}
 // Identifiers
-bare-id ::= letter (letter|digit|[_])*
+bare-id ::= (letter|[_]) (letter|digit|[_$.])*
 bare-id-list ::= bare-id (`,` bare-id)*
 suffix-id ::= digit+ | ((letter|id-punct) (letter|id-punct|digit)*)
 
index 067fe53dad3820d8cfa79191c3c02c5388bbac84..1024c139ad52e5f83fb999cd151edab159d82166 100644 (file)
@@ -103,6 +103,10 @@ public:
 
   virtual ~Dialect();
 
+  /// Utility function that returns if the given string is a valid dialect
+  /// namespace.
+  static bool isValidNamespace(StringRef str);
+
 protected:
   /// Note: The namePrefix can be empty, but it must not contain '.' characters.
   /// Note: If the name is non empty, then all operations belonging to this
index 527ce4a80825d34c3879114c11c86688d5c46667..18afba029603c75c068fc104342fec94f6cbe438 100644 (file)
@@ -289,12 +289,24 @@ public:
   static UnknownType get(Identifier dialect, StringRef typeData,
                          MLIRContext *context);
 
+  /// Get or create a new UnknownType with the provided dialect and string data.
+  /// If the given identifier is not a valid namespace for a dialect, then a
+  /// null type is returned.
+  static UnknownType getChecked(Identifier dialect, StringRef typeData,
+                                MLIRContext *context, Location location);
+
   /// Returns the dialect namespace of the unknown type.
   Identifier getDialectNamespace() const;
 
   /// Returns the raw type data of the unknown type.
   StringRef getTypeData() const;
 
+  /// Verify the construction of an unknown type.
+  static bool verifyConstructionInvariants(llvm::Optional<Location> loc,
+                                           MLIRContext *context,
+                                           Identifier dialect,
+                                           StringRef typeData);
+
   static bool kindof(unsigned kind) { return kind == Kind::Unknown; }
 
   /// Unique identifier for this type class.
index fbdf178769552629c8b1fd3d749d78e4508a4bd7..6abc467cffb5f13403c94de5dcabce790cf467be 100644 (file)
@@ -39,6 +39,7 @@
 #include "mlir/IR/Instruction.h"
 #include "mlir/IR/Module.h"
 #include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/Regex.h"
 #include "llvm/Support/raw_ostream.h"
 using namespace mlir;
 
@@ -66,7 +67,33 @@ public:
     return failure(message, fn);
   }
 
-  bool verifyAttribute(Attribute attr, const Instruction &op);
+  template <typename ErrorContext>
+  bool verifyAttribute(Attribute attr, const ErrorContext &ctx) {
+    if (!attr.isOrContainsFunction())
+      return false;
+
+    // If we have a function attribute, check that it is non-null and in the
+    // same module as the operation that refers to it.
+    if (auto fnAttr = attr.dyn_cast<FunctionAttr>()) {
+      if (!fnAttr.getValue())
+        return failure("attribute refers to deallocated function!", ctx);
+
+      if (fnAttr.getValue()->getModule() != fn.getModule())
+        return failure("attribute refers to function '" +
+                           Twine(fnAttr.getValue()->getName()) +
+                           "' defined in another module!",
+                       ctx);
+      return false;
+    }
+
+    // Otherwise, we must have an array attribute, remap the elements.
+    for (auto elt : attr.cast<ArrayAttr>().getValue()) {
+      if (verifyAttribute(elt, ctx))
+        return true;
+    }
+
+    return false;
+  }
 
   bool verify();
   bool verifyBlock(const Block &block, bool isTopLevel);
@@ -74,7 +101,8 @@ public:
   bool verifyDominance(const Block &block);
   bool verifyInstDominance(const Instruction &inst);
 
-  explicit FuncVerifier(const Function &fn) : fn(fn) {}
+  explicit FuncVerifier(const Function &fn)
+      : fn(fn), attrNameRegex("^:?[a-zA-Z_][a-zA-Z_0-9\\.\\$]*$") {}
 
 private:
   /// The function being checked.
@@ -82,6 +110,9 @@ private:
 
   /// Dominance information for this function, when checking dominance.
   DominanceInfo *domInfo = nullptr;
+
+  /// Regex checker for attribute names.
+  llvm::Regex attrNameRegex;
 };
 } // end anonymous namespace
 
@@ -93,11 +124,25 @@ bool FuncVerifier::verify() {
   if (fn.isExternal())
     return false;
 
+  // Check that the function name is valid.
+  llvm::Regex funcNameRegex("^[a-zA-Z][a-zA-Z_0-9\\.\\$]*$");
+  if (!funcNameRegex.match(fn.getName().strref()))
+    return failure("invalid function name '" + fn.getName().strref() + "'", fn);
+
   // Verify the first block has no predecessors.
   auto *firstBB = &fn.front();
   if (!firstBB->hasNoPredecessors())
     return failure("entry block of function may not have predecessors", fn);
 
+  /// Verify that all of the attributes are okay.
+  for (auto attr : fn.getAttrs()) {
+    if (!attrNameRegex.match(attr.first))
+      return failure("invalid attribute name '" + attr.first.strref() + "'",
+                     fn);
+    if (verifyAttribute(attr.second, fn))
+      return true;
+  }
+
   // Verify that the argument list of the function and the arg list of the first
   // block line up.
   auto fnInputTypes = fn.getType().getInputs();
@@ -133,34 +178,6 @@ bool FuncVerifier::verify() {
   return false;
 }
 
-// Check that function attributes are all well formed.
-bool FuncVerifier::verifyAttribute(Attribute attr, const Instruction &op) {
-  if (!attr.isOrContainsFunction())
-    return false;
-
-  // If we have a function attribute, check that it is non-null and in the
-  // same module as the operation that refers to it.
-  if (auto fnAttr = attr.dyn_cast<FunctionAttr>()) {
-    if (!fnAttr.getValue())
-      return failure("attribute refers to deallocated function!", op);
-
-    if (fnAttr.getValue()->getModule() != fn.getModule())
-      return failure("attribute refers to function '" +
-                         Twine(fnAttr.getValue()->getName()) +
-                         "' defined in another module!",
-                     op);
-    return false;
-  }
-
-  // Otherwise, we must have an array attribute, remap the elements.
-  for (auto elt : attr.cast<ArrayAttr>().getValue()) {
-    if (verifyAttribute(elt, op))
-      return true;
-  }
-
-  return false;
-}
-
 // Returns if the given block is allowed to have no terminator.
 static bool canBlockHaveNoTerminator(const Block &block) {
   // Allow the first block of an operation region to have no terminator if it is
@@ -224,10 +241,11 @@ bool FuncVerifier::verifyOperation(const Instruction &op) {
       return failure("reference to operand defined in another function", op);
   }
 
-  // Verify all attributes are ok.  We need to check Function attributes, since
-  // they are actually mutable (the function they refer to can be deleted), and
-  // we have to check array attributes that can refer to them.
+  /// Verify that all of the attributes are okay.
   for (auto attr : op.getAttrs()) {
+    if (!attrNameRegex.match(attr.first))
+      return failure("invalid attribute name '" + attr.first.strref() + "'",
+                     op);
     if (verifyAttribute(attr.second, op))
       return true;
   }
index c24d6b1f388d4406814dea2d689ec164246d6efc..d19b150fb48b25343968c171b6833a92976e9766 100644 (file)
@@ -20,6 +20,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/ManagedStatic.h"
+#include "llvm/Support/Regex.h"
 using namespace mlir;
 
 // Registry for all dialect allocation functions.
@@ -60,8 +61,7 @@ void mlir::registerAllDialects(MLIRContext *context) {
 
 Dialect::Dialect(StringRef namePrefix, MLIRContext *context)
     : namePrefix(namePrefix), context(context) {
-  assert(!namePrefix.contains('.') &&
-         "Dialect names cannot contain '.' characters.");
+  assert(isValidNamespace(namePrefix) && "invalid dialect namespace");
   registerDialect(context);
 }
 
@@ -74,3 +74,12 @@ Type Dialect::parseType(StringRef tyData, Location loc,
                               "' provides no type parsing hook");
   return Type();
 }
+
+/// Utility function that returns if the given string is a valid dialect
+/// namespace.
+bool Dialect::isValidNamespace(StringRef str) {
+  if (str.empty())
+    return true;
+  llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
+  return dialectNameRegex.match(str);
+}
index fa505b1366793d6a9f4df4557ee26d16c2e44ef3..c42143134b58e7d6526296fedb64233c24f6baaa 100644 (file)
@@ -18,6 +18,7 @@
 #include "mlir/IR/Types.h"
 #include "TypeDetail.h"
 #include "mlir/IR/Dialect.h"
+#include "llvm/ADT/Twine.h"
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -58,6 +59,11 @@ UnknownType UnknownType::get(Identifier dialect, StringRef typeData,
   return Base::get(context, Type::Kind::Unknown, dialect, typeData);
 }
 
+UnknownType UnknownType::getChecked(Identifier dialect, StringRef typeData,
+                                    MLIRContext *context, Location location) {
+  return Base::getChecked(location, context, Kind::Unknown, dialect, typeData);
+}
+
 /// Returns the dialect namespace of the unknown type.
 Identifier UnknownType::getDialectNamespace() const {
   return static_cast<ImplType *>(type)->dialectNamespace;
@@ -68,6 +74,20 @@ StringRef UnknownType::getTypeData() const {
   return static_cast<ImplType *>(type)->typeData;
 }
 
+/// Verify the construction of an unknown type.
+bool UnknownType::verifyConstructionInvariants(llvm::Optional<Location> loc,
+                                               MLIRContext *context,
+                                               Identifier dialect,
+                                               StringRef typeData) {
+  if (!Dialect::isValidNamespace(dialect.strref())) {
+    if (loc)
+      context->emitError(*loc, "invalid dialect namespace '" +
+                                   dialect.strref() + "'");
+    return true;
+  }
+  return false;
+}
+
 // Define type identifiers.
 char FunctionType::typeID = 0;
 char UnknownType::typeID = 0;
index 09532563661c71bd3646f8568b01fe26724fe97c..71b04140be166719a693da72ab478504ef77c522 100644 (file)
@@ -221,7 +221,8 @@ Token Lexer::lexAtIdentifier(const char *tokStart) {
   if (!isalpha(*curPtr++))
     return emitError(curPtr-1, "expected letter in @ identifier");
 
-  while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_')
+  while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_' ||
+         *curPtr == '$' || *curPtr == '.')
     ++curPtr;
   return formToken(Token::at_identifier, tokStart);
 }
index f89053599947fa18358322031bfd376acd616db7..6675490ad9f56d6a2dffb6e6e706c729b5d39f51 100644 (file)
@@ -520,8 +520,8 @@ Type Parser::parseExtendedType() {
       return nullptr;
   } else {
     // Otherwise, form a new unknown type.
-    result = UnknownType::get(Identifier::get(identifier, state.context),
-                              typeData, state.context);
+    result = UnknownType::getChecked(Identifier::get(identifier, state.context),
+                                     typeData, state.context, loc);
   }
 
   // Consume the '>'.
index fe0404efc62abe6eda29676ff6715fc632ce21c3..803aecda5f4e9f9da268bc87a7d9e0c643cde929 100644 (file)
@@ -908,3 +908,8 @@ func @invalid_nested_dominance() {
   }
   return
 }
+
+// -----
+
+// expected-error @+1 {{invalid dialect namespace 'invalid.dialect'}}
+func @invalid_unknown_type_dialect_name() -> !invalid.dialect<"">