From 1b8465aac4368c64d3e78ebd94fb8ca048b9e801 Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Tue, 23 May 2023 11:52:31 +0200 Subject: [PATCH] [mlir] Add CastInfo for mlir classes subclassing from PointerUnion MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit This is required to use the function variants of cast/isa/dyn_cast/etc on them. Context: - https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…" - Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443 --- mlir/include/mlir/Analysis/DataFlowFramework.h | 10 ++++++++ mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h | 13 +++++++++++ mlir/include/mlir/IR/OpDefinition.h | 32 +++++++++++++++++++++----- mlir/include/mlir/IR/Unit.h | 13 +++++++++++ mlir/include/mlir/Interfaces/CallInterfaces.h | 5 ++++ 5 files changed, 67 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h index 68f3db7..9649f91 100644 --- a/mlir/include/mlir/Analysis/DataFlowFramework.h +++ b/mlir/include/mlir/Analysis/DataFlowFramework.h @@ -470,6 +470,16 @@ namespace llvm { template <> struct DenseMapInfo : public DenseMapInfo {}; + +// Allow llvm::cast style functions. +template +struct CastInfo + : public CastInfo {}; + +template +struct CastInfo + : public CastInfo {}; + } // end namespace llvm #endif // MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index 3725cdd..5f88e10 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -236,4 +236,17 @@ SmallVector convertArrayToIndices(ArrayAttr attrs) { } // namespace LLVM } // namespace mlir +namespace llvm { + +// Allow llvm::cast style functions. +template +struct CastInfo + : public CastInfo {}; + +template +struct CastInfo + : public CastInfo {}; + +} // namespace llvm + #endif // MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 71864def..f3734dc 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -269,15 +269,35 @@ public: void dump() const { llvm::errs() << *this << "\n"; } }; +// Temporarily exit the MLIR namespace to add casting support as later code in +// this uses it. The CastInfo must come after the OpFoldResult definition and +// before any cast function calls depending on CastInfo. + +} // namespace mlir + +namespace llvm { + +// Allow llvm::cast style functions. +template +struct CastInfo + : public CastInfo {}; + +template +struct CastInfo + : public CastInfo {}; + +} // namespace llvm + +namespace mlir { + /// Allow printing to a stream. inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) { - if (Value value = ofr.dyn_cast()) + if (Value value = llvm::dyn_cast_if_present(ofr)) value.print(os); else - ofr.dyn_cast().print(os); + llvm::dyn_cast_if_present(ofr).print(os); return os; } - /// Allow printing to a stream. inline raw_ostream &operator<<(raw_ostream &os, OpState op) { op.print(os, OpPrintingFlags().useLocalScope()); @@ -1554,7 +1574,7 @@ foldTrait(Operation *op, ArrayRef operands, return failure(); if (OpFoldResult result = Trait::foldTrait(op, operands)) { - if (result.template dyn_cast() != op->getResult(0)) + if (llvm::dyn_cast_if_present(result) != op->getResult(0)) results.push_back(result); return success(); } @@ -1903,7 +1923,8 @@ private: // If the fold failed or was in-place, try to fold the traits of the // operation. - if (!result || result.template dyn_cast() == op->getResult(0)) { + if (!result || + llvm::dyn_cast_if_present(result) == op->getResult(0)) { if (succeeded(op_definition_impl::foldTraits...>( op, operands, results))) return success(); @@ -2119,7 +2140,6 @@ struct DenseMapInfo +struct CastInfo + : public CastInfo {}; + +template +struct CastInfo + : public CastInfo {}; + +} // namespace llvm + #endif // MLIR_IR_UNIT_H diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h index 26a245e..7dbcddb 100644 --- a/mlir/include/mlir/Interfaces/CallInterfaces.h +++ b/mlir/include/mlir/Interfaces/CallInterfaces.h @@ -30,10 +30,15 @@ struct CallInterfaceCallable : public PointerUnion { namespace llvm { +// Allow llvm::cast style functions. template struct CastInfo : public CastInfo {}; +template +struct CastInfo + : public CastInfo {}; + } // namespace llvm #endif // MLIR_INTERFACES_CALLINTERFACES_H -- 2.7.4