From 18066b52c2e2676438676cb069cf882e609288fe Mon Sep 17 00:00:00 2001 From: Nick Kreeger Date: Mon, 24 Oct 2022 14:32:12 -0500 Subject: [PATCH] [mlir] Update Location to use new casting infra This allows for using the llvm namespace cast methods instead of the ones on the Location class. The Location class method are kept for now, but we'll want to remove these eventually (with a really long lead time). Related change: https://reviews.llvm.org/D135870 Differential Revision: https://reviews.llvm.org/D136520 --- mlir/include/mlir/IR/Location.h | 39 ++++++++++++++++++++++++++++++++++++--- mlir/lib/AsmParser/Parser.cpp | 6 +++--- mlir/lib/IR/Diagnostics.cpp | 10 +++++----- mlir/lib/IR/Location.cpp | 2 +- 4 files changed, 45 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h index fc3ee12..6691c2f 100644 --- a/mlir/include/mlir/IR/Location.h +++ b/mlir/include/mlir/IR/Location.h @@ -65,15 +65,15 @@ public: /// Type casting utilities on the underlying location. template bool isa() const { - return impl.isa(); + return llvm::isa(*this); } template U dyn_cast() const { - return impl.dyn_cast(); + return llvm::dyn_cast(*this); } template U cast() const { - return impl.cast(); + return llvm::cast(*this); } /// Comparison operators. @@ -170,6 +170,39 @@ public: PointerLikeTypeTraits::NumLowBitsAvailable; }; +/// The constructors in mlir::Location ensure that the class is a non-nullable +/// wrapper around mlir::LocationAttr. Override default behavior and always +/// return true for isPresent(). +template <> +struct ValueIsPresent { + using UnwrappedType = mlir::Location; + static inline bool isPresent(const mlir::Location &location) { return true; } +}; + +/// Add support for llvm style casts. We provide a cast between To and From if +/// From is mlir::Location or derives from it. +template +struct CastInfo> || + std::is_base_of_v>> + : DefaultDoCastIfPossible> { + + static inline bool isPossible(mlir::Location location) { + /// Return a constant true instead of a dynamic true when casting to self or + /// up the hierarchy. Additionally, all casting info is deferred to the + /// wrapped mlir::LocationAttr instance stored in mlir::Location. + return std::is_same_v> || + isa(static_cast(location)); + } + + static inline To castFailed() { return To(); } + + static inline To doCast(mlir::Location location) { + return To(location->getImpl()); + } +}; + } // namespace llvm #endif diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index ac19412..4bb4fdf 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -768,7 +768,7 @@ ParseResult OperationParser::finalize() { auto &attributeAliases = state.symbols.attributeAliasDefinitions; auto locID = TypeID::get(); auto resolveLocation = [&, this](auto &opOrArgument) -> LogicalResult { - auto fwdLoc = opOrArgument.getLoc().template dyn_cast(); + auto fwdLoc = dyn_cast(opOrArgument.getLoc()); if (!fwdLoc || fwdLoc.getUnderlyingTypeID() != locID) return success(); auto locInfo = deferredLocsReferences[fwdLoc.getUnderlyingLocation()]; @@ -776,7 +776,7 @@ ParseResult OperationParser::finalize() { if (!attr) return this->emitError(locInfo.loc) << "operation location alias was never defined"; - auto locAttr = attr.dyn_cast(); + auto locAttr = dyn_cast(attr); if (!locAttr) return this->emitError(locInfo.loc) << "expected location, but found '" << attr << "'"; @@ -1930,7 +1930,7 @@ ParseResult OperationParser::parseLocationAlias(LocationAttr &loc) { // If this alias can be resolved, do it now. Attribute attr = state.symbols.attributeAliasDefinitions.lookup(identifier); if (attr) { - if (!(loc = attr.dyn_cast())) + if (!(loc = dyn_cast(attr))) return emitError(tok.getLoc()) << "expected location, but found '" << attr << "'"; } else { diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index 6d6f2b7..14aa44f 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -404,12 +404,12 @@ static Optional getFileLineColLoc(Location loc) { /// Return a processable CallSiteLoc from the given location. static Optional getCallSiteLoc(Location loc) { - if (auto nameLoc = loc.dyn_cast()) - return getCallSiteLoc(loc.cast().getChildLoc()); - if (auto callLoc = loc.dyn_cast()) + if (auto nameLoc = dyn_cast(loc)) + return getCallSiteLoc(cast(loc).getChildLoc()); + if (auto callLoc = dyn_cast(loc)) return callLoc; - if (auto fusedLoc = loc.dyn_cast()) { - for (auto subLoc : loc.cast().getLocations()) { + if (auto fusedLoc = dyn_cast(loc)) { + for (auto subLoc : cast(loc).getLocations()) { if (auto callLoc = getCallSiteLoc(subLoc)) { return callLoc; } diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp index cd19d59..8a8801d 100644 --- a/mlir/lib/IR/Location.cpp +++ b/mlir/lib/IR/Location.cpp @@ -105,7 +105,7 @@ Location FusedLoc::get(ArrayRef locs, Attribute metadata, for (auto loc : locs) { // If the location is a fused location we decompose it if it has no // metadata or the metadata is the same as the top level metadata. - if (auto fusedLoc = loc.dyn_cast()) { + if (auto fusedLoc = llvm::dyn_cast(loc)) { if (fusedLoc.getMetadata() == metadata) { // UnknownLoc's have already been removed from FusedLocs so we can // simply add all of the internal locations. -- 2.7.4