[mlir] Add CastInfo for mlir classes subclassing from PointerUnion
authorTres Popp <tpopp@google.com>
Tue, 23 May 2023 09:52:31 +0000 (11:52 +0200)
committerTres Popp <tpopp@google.com>
Fri, 26 May 2023 05:47:03 +0000 (07:47 +0200)
This is required to use the function variants of cast/isa/dyn_cast/etc
on them.

Context:
- https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…"
- Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443

mlir/include/mlir/Analysis/DataFlowFramework.h
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Unit.h
mlir/include/mlir/Interfaces/CallInterfaces.h

index 68f3db7..9649f91 100644 (file)
@@ -470,6 +470,16 @@ namespace llvm {
 template <>
 struct DenseMapInfo<mlir::ProgramPoint>
     : public DenseMapInfo<mlir::ProgramPoint::ParentTy> {};
+
+// Allow llvm::cast style functions.
+template <typename To>
+struct CastInfo<To, mlir::ProgramPoint>
+    : public CastInfo<To, mlir::ProgramPoint::PointerUnion> {};
+
+template <typename To>
+struct CastInfo<To, const mlir::ProgramPoint>
+    : public CastInfo<To, const mlir::ProgramPoint::PointerUnion> {};
+
 } // end namespace llvm
 
 #endif // MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
index 3725cdd..5f88e10 100644 (file)
@@ -236,4 +236,17 @@ SmallVector<IntT> convertArrayToIndices(ArrayAttr attrs) {
 } // namespace LLVM
 } // namespace mlir
 
+namespace llvm {
+
+// Allow llvm::cast style functions.
+template <typename To>
+struct CastInfo<To, mlir::LLVM::GEPArg>
+    : public CastInfo<To, mlir::LLVM::GEPArg::PointerUnion> {};
+
+template <typename To>
+struct CastInfo<To, const mlir::LLVM::GEPArg>
+    : public CastInfo<To, const mlir::LLVM::GEPArg::PointerUnion> {};
+
+} // namespace llvm
+
 #endif // MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_
index 71864de..f3734dc 100644 (file)
@@ -269,15 +269,35 @@ public:
   void dump() const { llvm::errs() << *this << "\n"; }
 };
 
+// Temporarily exit the MLIR namespace to add casting support as later code in
+// this uses it. The CastInfo must come after the OpFoldResult definition and
+// before any cast function calls depending on CastInfo.
+
+} // namespace mlir
+
+namespace llvm {
+
+// Allow llvm::cast style functions.
+template <typename To>
+struct CastInfo<To, mlir::OpFoldResult>
+    : public CastInfo<To, mlir::OpFoldResult::PointerUnion> {};
+
+template <typename To>
+struct CastInfo<To, const mlir::OpFoldResult>
+    : public CastInfo<To, const mlir::OpFoldResult::PointerUnion> {};
+
+} // namespace llvm
+
+namespace mlir {
+
 /// Allow printing to a stream.
 inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) {
-  if (Value value = ofr.dyn_cast<Value>())
+  if (Value value = llvm::dyn_cast_if_present<Value>(ofr))
     value.print(os);
   else
-    ofr.dyn_cast<Attribute>().print(os);
+    llvm::dyn_cast_if_present<Attribute>(ofr).print(os);
   return os;
 }
-
 /// Allow printing to a stream.
 inline raw_ostream &operator<<(raw_ostream &os, OpState op) {
   op.print(os, OpPrintingFlags().useLocalScope());
@@ -1554,7 +1574,7 @@ foldTrait(Operation *op, ArrayRef<Attribute> operands,
     return failure();
 
   if (OpFoldResult result = Trait::foldTrait(op, operands)) {
-    if (result.template dyn_cast<Value>() != op->getResult(0))
+    if (llvm::dyn_cast_if_present<Value>(result) != op->getResult(0))
       results.push_back(result);
     return success();
   }
@@ -1903,7 +1923,8 @@ private:
 
     // If the fold failed or was in-place, try to fold the traits of the
     // operation.
-    if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
+    if (!result ||
+        llvm::dyn_cast_if_present<Value>(result) == op->getResult(0)) {
       if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
               op, operands, results)))
         return success();
@@ -2119,7 +2140,6 @@ struct DenseMapInfo<T,
   }
   static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
 };
-
 } // namespace llvm
 
 #endif
index 033dab5..63117a7 100644 (file)
@@ -39,4 +39,17 @@ raw_ostream &operator<<(raw_ostream &os, const IRUnit &unit);
 
 } // end namespace mlir
 
+namespace llvm {
+
+// Allow llvm::cast style functions.
+template <typename To>
+struct CastInfo<To, mlir::IRUnit>
+    : public CastInfo<To, mlir::IRUnit::PointerUnion> {};
+
+template <typename To>
+struct CastInfo<To, const mlir::IRUnit>
+    : public CastInfo<To, const mlir::IRUnit::PointerUnion> {};
+
+} // namespace llvm
+
 #endif // MLIR_IR_UNIT_H
index 26a245e..7dbcddb 100644 (file)
@@ -30,10 +30,15 @@ struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value> {
 
 namespace llvm {
 
+// Allow llvm::cast style functions.
 template <typename To>
 struct CastInfo<To, mlir::CallInterfaceCallable>
     : public CastInfo<To, mlir::CallInterfaceCallable::PointerUnion> {};
 
+template <typename To>
+struct CastInfo<To, const mlir::CallInterfaceCallable>
+    : public CastInfo<To, const mlir::CallInterfaceCallable::PointerUnion> {};
+
 } // namespace llvm
 
 #endif // MLIR_INTERFACES_CALLINTERFACES_H