[mlir] Simplify various pieces of code now that Identifier has access to the Context...
authorRiver Riddle <riddleriver@gmail.com>
Sat, 27 Feb 2021 01:57:03 +0000 (17:57 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Sat, 27 Feb 2021 02:00:05 +0000 (18:00 -0800)
This also exposed a bug in Dialect loading where it was not correctly identifying identifiers that had the dialect namespace as a prefix.

Differential Revision: https://reviews.llvm.org/D97431

28 files changed:
mlir/examples/toy/Ch2/mlir/MLIRGen.cpp
mlir/examples/toy/Ch3/mlir/MLIRGen.cpp
mlir/examples/toy/Ch4/mlir/MLIRGen.cpp
mlir/examples/toy/Ch5/mlir/MLIRGen.cpp
mlir/examples/toy/Ch6/mlir/MLIRGen.cpp
mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/include/mlir/IR/Location.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/Dialect.cpp
mlir/lib/IR/Location.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/Verifier.cpp
mlir/lib/Parser/DialectSymbolParser.cpp
mlir/lib/Parser/LocationParser.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Transforms/LocationSnapshot.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp
mlir/unittests/IR/AttributeTest.cpp

index 8b9f9db..39b8681 100644 (file)
@@ -93,7 +93,7 @@ private:
 
   /// Helper conversion for a Toy AST location to an MLIR location.
   mlir::Location loc(Location loc) {
-    return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
+    return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
                                      loc.col);
   }
 
index 8b9f9db..39b8681 100644 (file)
@@ -93,7 +93,7 @@ private:
 
   /// Helper conversion for a Toy AST location to an MLIR location.
   mlir::Location loc(Location loc) {
-    return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
+    return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
                                      loc.col);
   }
 
index b5e6c70..3b90fc2 100644 (file)
@@ -93,7 +93,7 @@ private:
 
   /// Helper conversion for a Toy AST location to an MLIR location.
   mlir::Location loc(Location loc) {
-    return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
+    return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
                                      loc.col);
   }
 
index b5e6c70..3b90fc2 100644 (file)
@@ -93,7 +93,7 @@ private:
 
   /// Helper conversion for a Toy AST location to an MLIR location.
   mlir::Location loc(Location loc) {
-    return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
+    return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
                                      loc.col);
   }
 
index b5e6c70..3b90fc2 100644 (file)
@@ -93,7 +93,7 @@ private:
 
   /// Helper conversion for a Toy AST location to an MLIR location.
   mlir::Location loc(Location loc) {
-    return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
+    return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
                                      loc.col);
   }
 
index e69e9a2..4dca519 100644 (file)
@@ -113,7 +113,7 @@ private:
 
   /// Helper conversion for a Toy AST location to an MLIR location.
   mlir::Location loc(Location loc) {
-    return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
+    return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
                                      loc.col);
   }
 
index eaaabc8..f8b119c 100644 (file)
@@ -56,8 +56,6 @@ public:
 
   // Locations.
   Location getUnknownLoc();
-  Location getFileLineColLoc(Identifier filename, unsigned line,
-                             unsigned column);
   Location getFusedLoc(ArrayRef<Location> locs,
                        Attribute metadata = Attribute());
 
index 994308e..ab9aa8d 100644 (file)
@@ -296,8 +296,7 @@ public:
   using Base::getChecked;
 
   /// Get or create a new OpaqueAttr with the provided dialect and string data.
-  static OpaqueAttr get(MLIRContext *context, Identifier dialect,
-                        StringRef attrData, Type type);
+  static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type);
 
   /// Get or create a new OpaqueAttr with the provided dialect and string data.
   /// If the given identifier is not a valid namespace for a dialect, then a
index 4133b05..7f1bf61 100644 (file)
@@ -293,6 +293,15 @@ def Builtin_Opaque : Builtin_Type<"Opaque"> {
     "Identifier":$dialectNamespace,
     StringRefParameter<"">:$typeData
   );
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "Identifier":$dialectNamespace, CArg<"StringRef", "{}">:$typeData
+    ), [{
+      return $_get(dialectNamespace.getContext(), dialectNamespace, typeData);
+    }]>
+  ];
+  let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
 }
 
index aed0c42..279ac3e 100644 (file)
@@ -129,8 +129,7 @@ public:
   using Base::Base;
 
   /// Return a uniqued FileLineCol location object.
-  static Location get(Identifier filename, unsigned line, unsigned column,
-                      MLIRContext *context);
+  static Location get(Identifier filename, unsigned line, unsigned column);
   static Location get(StringRef filename, unsigned line, unsigned column,
                       MLIRContext *context);
 
@@ -174,7 +173,7 @@ public:
   static Location get(Identifier name, Location child);
 
   /// Return a uniqued name location object with an unknown child.
-  static Location get(Identifier name, MLIRContext *context);
+  static Location get(Identifier name);
 
   /// Return the name identifier.
   Identifier getName() const;
index bf3a6ad..c0089f7 100644 (file)
@@ -491,7 +491,7 @@ def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
 class OpaqueType<string dialect, string name, string summary>
   : Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
          summary, "::mlir::OpaqueType">,
-    BuildableType<"::mlir::OpaqueType::get($_builder.getContext(), "
+    BuildableType<"::mlir::OpaqueType::get("
                   "$_builder.getIdentifier(\"" # dialect # "\"), \""
                   # name # "\")">;
 
index 6dcd716..4e2aef3 100644 (file)
@@ -314,7 +314,15 @@ public:
   OperationName(StringRef name, MLIRContext *context);
 
   /// Return the name of the dialect this operation is registered to.
-  StringRef getDialect() const;
+  StringRef getDialectNamespace() const;
+
+  /// Return the Dialect this operation is registered to if it is loaded in the
+  /// context, or nullptr if the dialect isn't loaded.
+  Dialect *getDialect() const {
+    if (const auto *abstractOp = getAbstractOperation())
+      return &abstractOp->dialect;
+    return representation.get<Identifier>().getDialect();
+  }
 
   /// Return the operation name with dialect name stripped, if it has one.
   StringRef stripDialect() const;
index 6d36da6..a54006d 100644 (file)
@@ -163,9 +163,9 @@ bool mlirAttributeIsAOpaque(MlirAttribute attr) {
 MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
                                 intptr_t dataLength, const char *data,
                                 MlirType type) {
-  return wrap(OpaqueAttr::get(
-      unwrap(ctx), Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
-      StringRef(data, dataLength), unwrap(type)));
+  return wrap(
+      OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
+                      StringRef(data, dataLength), unwrap(type)));
 }
 
 MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
index bafeccb..d1ab379 100644 (file)
@@ -29,11 +29,6 @@ Identifier Builder::getIdentifier(StringRef str) {
 
 Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
 
-Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
-                                    unsigned column) {
-  return FileLineColLoc::get(filename, line, column, context);
-}
-
 Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
   return FusedLoc::get(locs, metadata, context);
 }
index f3b3291..cccef6c 100644 (file)
@@ -382,9 +382,8 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
 // OpaqueAttr
 //===----------------------------------------------------------------------===//
 
-OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect,
-                           StringRef attrData, Type type) {
-  return Base::get(context, dialect, attrData, type);
+OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type) {
+  return Base::get(dialect.getContext(), dialect, attrData, type);
 }
 
 OpaqueAttr OpaqueAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
index 228bb8d..fc21152 100644 (file)
@@ -127,8 +127,8 @@ Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
 Type Dialect::parseType(DialectAsmParser &parser) const {
   // If this dialect allows unknown types, then represent this with OpaqueType.
   if (allowsUnknownTypes()) {
-    auto ns = Identifier::get(getNamespace(), getContext());
-    return OpaqueType::get(getContext(), ns, parser.getFullSymbolSpec());
+    Identifier ns = Identifier::get(getNamespace(), getContext());
+    return OpaqueType::get(ns, parser.getFullSymbolSpec());
   }
 
   parser.emitError(parser.getNameLoc())
index 151e2cf..be7b05b 100644 (file)
@@ -48,14 +48,14 @@ Location CallSiteLoc::getCaller() const { return getImpl()->caller; }
 //===----------------------------------------------------------------------===//
 
 Location FileLineColLoc::get(Identifier filename, unsigned line,
-                             unsigned column, MLIRContext *context) {
-  return Base::get(context, filename, line, column);
+                             unsigned column) {
+  return Base::get(filename.getContext(), filename, line, column);
 }
 
 Location FileLineColLoc::get(StringRef filename, unsigned line, unsigned column,
                              MLIRContext *context) {
   return get(Identifier::get(filename.empty() ? "-" : filename, context), line,
-             column, context);
+             column);
 }
 
 StringRef FileLineColLoc::getFilename() const { return getImpl()->filename; }
@@ -112,8 +112,8 @@ Location NameLoc::get(Identifier name, Location child) {
   return Base::get(child->getContext(), name, child);
 }
 
-Location NameLoc::get(Identifier name, MLIRContext *context) {
-  return get(name, UnknownLoc::get(context));
+Location NameLoc::get(Identifier name) {
+  return get(name, UnknownLoc::get(name.getContext()));
 }
 
 /// Return the name identifier.
index 641594e..c420412 100644 (file)
@@ -520,9 +520,11 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
     // Refresh all the identifiers dialect field, this catches cases where a
     // dialect may be loaded after identifier prefixed with this dialect name
     // were already created.
+    llvm::SmallString<32> dialectPrefix(dialectNamespace);
+    dialectPrefix.push_back('.');
     for (auto &identifierEntry : impl.identifiers)
-      if (!identifierEntry.second &&
-          identifierEntry.first().startswith(dialectNamespace))
+      if (identifierEntry.second.is<MLIRContext *>() &&
+          identifierEntry.first().startswith(dialectPrefix))
         identifierEntry.second = dialect.get();
 
     // Actually register the interfaces with delayed registration.
index 9349e4c..b1f2ad6 100644 (file)
@@ -35,8 +35,10 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
 }
 
 /// Return the name of the dialect this operation is registered to.
-StringRef OperationName::getDialect() const {
-  return getStringRef().split('.').first;
+StringRef OperationName::getDialectNamespace() const {
+  if (Dialect *dialect = getDialect())
+    return dialect->getNamespace();
+  return representation.get<Identifier>().strref().split('.').first;
 }
 
 /// Return the operation name with dialect name stripped, if it has one.
@@ -213,14 +215,7 @@ MLIRContext *Operation::getContext() { return location->getContext(); }
 
 /// Return the dialect this operation is associated with, or nullptr if the
 /// associated dialect is not registered.
-Dialect *Operation::getDialect() {
-  if (auto *abstractOp = getAbstractOperation())
-    return &abstractOp->dialect;
-
-  // If this operation hasn't been registered or doesn't have abstract
-  // operation, try looking up the dialect name in the context.
-  return getContext()->getLoadedDialect(getName().getDialect());
-}
+Dialect *Operation::getDialect() { return getName().getDialect(); }
 
 Region *Operation::getParentRegion() {
   return block ? block->getParent() : nullptr;
index 9b047e7..6aadab9 100644 (file)
@@ -46,13 +46,6 @@ public:
   /// Verify the given operation.
   LogicalResult verify(Operation &op);
 
-  /// Returns the registered dialect for a dialect-specific attribute.
-  Dialect *getDialectForAttribute(const NamedAttribute &attr) {
-    assert(attr.first.strref().contains('.') && "expected dialect attribute");
-    auto dialectNamePair = attr.first.strref().split('.');
-    return ctx->getLoadedDialect(dialectNamePair.first);
-  }
-
 private:
   /// Verify the given potentially nested region or block.
   LogicalResult verifyRegion(Region &region);
@@ -81,10 +74,6 @@ private:
 
   /// Dominance information for this operation, when checking dominance.
   DominanceInfo *domInfo = nullptr;
-
-  /// Mapping between dialect namespace and if that dialect supports
-  /// unregistered operations.
-  llvm::StringMap<bool> dialectAllowsUnknownOps;
 };
 } // end anonymous namespace
 
@@ -170,15 +159,14 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) {
   /// Verify that all of the attributes are okay.
   for (auto attr : op.getAttrs()) {
     // Check for any optional dialect specific attributes.
-    if (!attr.first.strref().contains('.'))
-      continue;
-    if (auto *dialect = getDialectForAttribute(attr))
+    if (auto *dialect = attr.first.getDialect())
       if (failed(dialect->verifyOperationAttribute(&op, attr)))
         return failure();
   }
 
   // If we can get operation info for this, check the custom hook.
-  auto *opInfo = op.getAbstractOperation();
+  OperationName opName = op.getName();
+  auto *opInfo = opName.getAbstractOperation();
   if (opInfo && failed(opInfo->verifyInvariants(&op)))
     return failure();
 
@@ -213,33 +201,21 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) {
     return success();
 
   // Otherwise, verify that the parent dialect allows un-registered operations.
-  auto dialectPrefix = op.getName().getDialect();
-
-  // Check for an existing answer for the operation dialect.
-  auto it = dialectAllowsUnknownOps.find(dialectPrefix);
-  if (it == dialectAllowsUnknownOps.end()) {
-    // If the operation dialect is registered, query it directly.
-    if (auto *dialect = ctx->getLoadedDialect(dialectPrefix))
-      it = dialectAllowsUnknownOps
-               .try_emplace(dialectPrefix, dialect->allowsUnknownOperations())
-               .first;
-    // Otherwise, unregistered dialects (when allowed by the context)
-    // conservatively allow unknown operations.
-    else {
-      if (!op.getContext()->allowsUnregisteredDialects() && !op.getDialect())
-        return op.emitOpError()
-               << "created with unregistered dialect. If this is "
-                  "intended, please call allowUnregisteredDialects() on the "
-                  "MLIRContext, or use -allow-unregistered-dialect with "
-                  "mlir-opt";
-
-      it = dialectAllowsUnknownOps.try_emplace(dialectPrefix, true).first;
+  Dialect *dialect = opName.getDialect();
+  if (!dialect) {
+    if (!ctx->allowsUnregisteredDialects()) {
+      return op.emitOpError()
+             << "created with unregistered dialect. If this is "
+                "intended, please call allowUnregisteredDialects() on the "
+                "MLIRContext, or use -allow-unregistered-dialect with "
+                "mlir-opt";
     }
+    return success();
   }
 
-  if (!it->second) {
+  if (!dialect->allowsUnknownOperations()) {
     return op.emitError("unregistered operation '")
-           << op.getName() << "' found in dialect ('" << dialectPrefix
+           << op.getName() << "' found in dialect ('" << dialect->getNamespace()
            << "') that does not allow unknown operations";
   }
 
index 16efd0a..6993b8e 100644 (file)
@@ -563,7 +563,7 @@ Type Parser::parseExtendedType() {
 
         // Otherwise, form a new opaque type.
         return OpaqueType::getChecked(
-            getEncodedSourceLocation(loc), state.context,
+            getEncodedSourceLocation(loc),
             Identifier::get(dialectName, state.context), symbolData);
       });
 }
index 93977f0..977982e 100644 (file)
@@ -145,7 +145,7 @@ ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
                    "expected ')' after child location of NameLoc"))
       return failure();
   } else {
-    loc = NameLoc::get(Identifier::get(str, ctx), ctx);
+    loc = NameLoc::get(Identifier::get(str, ctx));
   }
 
   return success();
index 0f43ac2..d423c39 100644 (file)
@@ -1944,8 +1944,8 @@ Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
   auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
   if (fileName.empty())
     fileName = "<unknown>";
-  return opBuilder.getFileLineColLoc(opBuilder.getIdentifier(fileName),
-                                     debugLine->line, debugLine->col);
+  return FileLineColLoc::get(opBuilder.getIdentifier(fileName), debugLine->line,
+                             debugLine->col);
 }
 
 LogicalResult
index 0b1d929..7d4284a 100644 (file)
@@ -44,8 +44,7 @@ static void generateLocationsFromIR(raw_ostream &os, StringRef fileName,
     if (it == opToLineCol.end())
       return;
     const std::pair<unsigned, unsigned> &lineCol = it->second;
-    auto newLoc =
-        builder.getFileLineColLoc(file, lineCol.first, lineCol.second);
+    auto newLoc = FileLineColLoc::get(file, lineCol.first, lineCol.second);
 
     // If we don't have a tag, set the location directly
     if (!tagIdentifier) {
index de3c436..002843c 100644 (file)
@@ -2702,10 +2702,10 @@ auto ConversionTarget::getOpInfo(OperationName op) const
   if (it != legalOperations.end())
     return it->second;
   // Check for info for the parent dialect.
-  auto dialectIt = legalDialects.find(op.getDialect());
+  auto dialectIt = legalDialects.find(op.getDialectNamespace());
   if (dialectIt != legalDialects.end()) {
     Optional<DynamicLegalityCallbackFn> callback;
-    auto dialectFn = dialectLegalityFns.find(op.getDialect());
+    auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
     if (dialectFn != dialectLegalityFns.end())
       callback = dialectFn->second;
     return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
index 582eecd..5925e55 100644 (file)
@@ -862,8 +862,7 @@ int main(int argc, char **argv) {
     }
 
     genContext.setLoc(NameLoc::get(
-        Identifier::get(opConfig.metadata->cppOpName, &mlirContext),
-        &mlirContext));
+        Identifier::get(opConfig.metadata->cppOpName, &mlirContext)));
     if (failed(generateOp(opConfig, genContext))) {
       return 1;
     }
index 1b28bf8..5781870 100644 (file)
@@ -842,8 +842,7 @@ std::string PatternEmitter::handleLocationDirective(DagNode tree) {
   if (tree.getNumArgs() == 1) {
     DagLeaf leaf = tree.getArgAsLeaf(0);
     if (leaf.isStringAttr())
-      return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"), "
-                     "rewriter.getContext())",
+      return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"))",
                      leaf.getStringAttr())
           .str();
     return lookUpArgLoc(0);
index 73acb70..b18fe8a 100644 (file)
@@ -151,7 +151,7 @@ TEST(DenseSplatTest, BF16Splat) {
 TEST(DenseSplatTest, StringSplat) {
   MLIRContext context;
   Type stringType =
-      OpaqueType::get(&context, Identifier::get("test", &context), "string");
+      OpaqueType::get(Identifier::get("test", &context), "string");
   StringRef value = "test-string";
   testSplat(stringType, value);
 }
@@ -159,7 +159,7 @@ TEST(DenseSplatTest, StringSplat) {
 TEST(DenseSplatTest, StringAttrSplat) {
   MLIRContext context;
   Type stringType =
-      OpaqueType::get(&context, Identifier::get("test", &context), "string");
+      OpaqueType::get(Identifier::get("test", &context), "string");
   Attribute stringAttr = StringAttr::get("test-string", stringType);
   testSplat(stringType, stringAttr);
 }