[mlir] Update Location to use new casting infra
authorNick Kreeger <nick.kreeger@gmail.com>
Mon, 24 Oct 2022 19:32:12 +0000 (14:32 -0500)
committerNick Kreeger <nick.kreeger@gmail.com>
Mon, 24 Oct 2022 19:32:12 +0000 (14:32 -0500)
This allows for using the llvm namespace cast methods instead of the ones on the Location class. The Location class method are kept for now, but we'll want to remove these eventually (with a really long lead time).

Related change: https://reviews.llvm.org/D135870

Differential Revision: https://reviews.llvm.org/D136520

mlir/include/mlir/IR/Location.h
mlir/lib/AsmParser/Parser.cpp
mlir/lib/IR/Diagnostics.cpp
mlir/lib/IR/Location.cpp

index fc3ee12..6691c2f 100644 (file)
@@ -65,15 +65,15 @@ public:
   /// Type casting utilities on the underlying location.
   template <typename U>
   bool isa() const {
-    return impl.isa<U>();
+    return llvm::isa<U>(*this);
   }
   template <typename U>
   U dyn_cast() const {
-    return impl.dyn_cast<U>();
+    return llvm::dyn_cast<U>(*this);
   }
   template <typename U>
   U cast() const {
-    return impl.cast<U>();
+    return llvm::cast<U>(*this);
   }
 
   /// Comparison operators.
@@ -170,6 +170,39 @@ public:
       PointerLikeTypeTraits<mlir::Attribute>::NumLowBitsAvailable;
 };
 
+/// The constructors in mlir::Location ensure that the class is a non-nullable
+/// wrapper around mlir::LocationAttr. Override default behavior and always
+/// return true for isPresent().
+template <>
+struct ValueIsPresent<mlir::Location> {
+  using UnwrappedType = mlir::Location;
+  static inline bool isPresent(const mlir::Location &location) { return true; }
+};
+
+/// Add support for llvm style casts. We provide a cast between To and From if
+/// From is mlir::Location or derives from it.
+template <typename To, typename From>
+struct CastInfo<To, From,
+                std::enable_if_t<
+                    std::is_same_v<mlir::Location, std::remove_const_t<From>> ||
+                    std::is_base_of_v<mlir::Location, From>>>
+    : DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+
+  static inline bool isPossible(mlir::Location location) {
+    /// Return a constant true instead of a dynamic true when casting to self or
+    /// up the hierarchy. Additionally, all casting info is deferred to the
+    /// wrapped mlir::LocationAttr instance stored in mlir::Location.
+    return std::is_same_v<To, std::remove_const_t<From>> ||
+           isa<To>(static_cast<mlir::LocationAttr>(location));
+  }
+
+  static inline To castFailed() { return To(); }
+
+  static inline To doCast(mlir::Location location) {
+    return To(location->getImpl());
+  }
+};
+
 } // namespace llvm
 
 #endif
index ac19412..4bb4fdf 100644 (file)
@@ -768,7 +768,7 @@ ParseResult OperationParser::finalize() {
   auto &attributeAliases = state.symbols.attributeAliasDefinitions;
   auto locID = TypeID::get<DeferredLocInfo *>();
   auto resolveLocation = [&, this](auto &opOrArgument) -> LogicalResult {
-    auto fwdLoc = opOrArgument.getLoc().template dyn_cast<OpaqueLoc>();
+    auto fwdLoc = dyn_cast<OpaqueLoc>(opOrArgument.getLoc());
     if (!fwdLoc || fwdLoc.getUnderlyingTypeID() != locID)
       return success();
     auto locInfo = deferredLocsReferences[fwdLoc.getUnderlyingLocation()];
@@ -776,7 +776,7 @@ ParseResult OperationParser::finalize() {
     if (!attr)
       return this->emitError(locInfo.loc)
              << "operation location alias was never defined";
-    auto locAttr = attr.dyn_cast<LocationAttr>();
+    auto locAttr = dyn_cast<LocationAttr>(attr);
     if (!locAttr)
       return this->emitError(locInfo.loc)
              << "expected location, but found '" << attr << "'";
@@ -1930,7 +1930,7 @@ ParseResult OperationParser::parseLocationAlias(LocationAttr &loc) {
   // If this alias can be resolved, do it now.
   Attribute attr = state.symbols.attributeAliasDefinitions.lookup(identifier);
   if (attr) {
-    if (!(loc = attr.dyn_cast<LocationAttr>()))
+    if (!(loc = dyn_cast<LocationAttr>(attr)))
       return emitError(tok.getLoc())
              << "expected location, but found '" << attr << "'";
   } else {
index 6d6f2b7..14aa44f 100644 (file)
@@ -404,12 +404,12 @@ static Optional<FileLineColLoc> getFileLineColLoc(Location loc) {
 
 /// Return a processable CallSiteLoc from the given location.
 static Optional<CallSiteLoc> getCallSiteLoc(Location loc) {
-  if (auto nameLoc = loc.dyn_cast<NameLoc>())
-    return getCallSiteLoc(loc.cast<NameLoc>().getChildLoc());
-  if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
+  if (auto nameLoc = dyn_cast<NameLoc>(loc))
+    return getCallSiteLoc(cast<NameLoc>(loc).getChildLoc());
+  if (auto callLoc = dyn_cast<CallSiteLoc>(loc))
     return callLoc;
-  if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
-    for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
+  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
+    for (auto subLoc : cast<FusedLoc>(loc).getLocations()) {
       if (auto callLoc = getCallSiteLoc(subLoc)) {
         return callLoc;
       }
index cd19d59..8a8801d 100644 (file)
@@ -105,7 +105,7 @@ Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,
   for (auto loc : locs) {
     // If the location is a fused location we decompose it if it has no
     // metadata or the metadata is the same as the top level metadata.
-    if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
+    if (auto fusedLoc = llvm::dyn_cast<FusedLoc>(loc)) {
       if (fusedLoc.getMetadata() == metadata) {
         // UnknownLoc's have already been removed from FusedLocs so we can
         // simply add all of the internal locations.