[MLIR][Operation] Simplify Operation casting, NFC
authorbzcheeseman <12992886+bzcheeseman@users.noreply.github.com>
Wed, 11 May 2022 19:25:04 +0000 (15:25 -0400)
committerbzcheeseman <12992886+bzcheeseman@users.noreply.github.com>
Thu, 12 May 2022 04:17:01 +0000 (00:17 -0400)
We can simplify the code needed to implement dyn_cast/cast/isa support for MLIR operations with documented interfaces via the CastInfo structures. This will also provide an example of how to use CastInfo.

Depends on D123901

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D124963

mlir/include/mlir/IR/Operation.h

index 4594cb6..86b5a94 100644 (file)
@@ -172,7 +172,8 @@ public:
   Operation *getParentOp() { return block ? block->getParentOp() : nullptr; }
 
   /// Return the closest surrounding parent operation that is of type 'OpTy'.
-  template <typename OpTy> OpTy getParentOfType() {
+  template <typename OpTy>
+  OpTy getParentOfType() {
     auto *op = this;
     while ((op = op->getParentOp()))
       if (auto parentOp = dyn_cast<OpTy>(op))
@@ -521,14 +522,16 @@ public:
 
   /// Returns true if the operation was registered with a particular trait, e.g.
   /// hasTrait<OperandsAreSignlessIntegerLike>().
-  template <template <typename T> class Trait> bool hasTrait() {
+  template <template <typename T> class Trait>
+  bool hasTrait() {
     return name.hasTrait<Trait>();
   }
 
   /// Returns true if the operation *might* have the provided trait. This
   /// means that either the operation is unregistered, or it was registered with
   /// the provide trait.
-  template <template <typename T> class Trait> bool mightHaveTrait() {
+  template <template <typename T> class Trait>
+  bool mightHaveTrait() {
     return name.mightHaveTrait<Trait>();
   }
 
@@ -804,34 +807,33 @@ inline raw_ostream &operator<<(raw_ostream &os, const Operation &op) {
 } // namespace mlir
 
 namespace llvm {
-/// Provide isa functionality for operation casts.
-template <typename T> struct isa_impl<T, ::mlir::Operation> {
-  static inline bool doit(const ::mlir::Operation &op) {
-    return T::classof(const_cast<::mlir::Operation *>(&op));
-  }
-};
-
-/// Allow isa<Operation *> on operations.
-template <> struct isa_impl<::mlir::Operation *, ::mlir::Operation> {
-  static inline bool doit(const ::mlir::Operation &op) { return true; }
-};
-
-/// Provide specializations for operation casts as the resulting T is value
-/// typed.
-template <typename T> struct cast_retty_impl<T, ::mlir::Operation *> {
-  using ret_type = T;
-};
-template <typename T> struct cast_retty_impl<T, ::mlir::Operation> {
-  using ret_type = T;
-};
-template <class T>
-struct cast_convert_val<T, ::mlir::Operation, ::mlir::Operation> {
-  static T doit(::mlir::Operation &val) { return T(&val); }
+/// Cast from an (const) Operation * to a derived operation type.
+template <typename T>
+struct CastInfo<T, ::mlir::Operation *>
+    : public ValueFromPointerCast<T, ::mlir::Operation,
+                                  CastInfo<T, ::mlir::Operation *>> {
+  static bool isPossible(::mlir::Operation *op) { return T::classof(op); }
 };
-template <class T>
-struct cast_convert_val<T, ::mlir::Operation *, ::mlir::Operation *> {
-  static T doit(::mlir::Operation *val) { return T(val); }
+template <typename T>
+struct CastInfo<T, const ::mlir::Operation *>
+    : public ConstStrippingForwardingCast<T, const ::mlir::Operation *,
+                                          CastInfo<T, ::mlir::Operation *>> {};
+
+/// Cast from an (const) Operation & to a derived operation type.
+template <typename T>
+struct CastInfo<T, ::mlir::Operation>
+    : public NullableValueCastFailed<T>,
+      public DefaultDoCastIfPossible<T, ::mlir::Operation &,
+                                     CastInfo<T, ::mlir::Operation>> {
+  // Provide isPossible here because here we have the const-stripping from
+  // ConstStrippingCast.
+  static bool isPossible(::mlir::Operation &val) { return T::classof(&val); }
+  static T doCast(::mlir::Operation &val) { return T(&val); }
 };
+template <typename T>
+struct CastInfo<T, const ::mlir::Operation>
+    : public ConstStrippingForwardingCast<T, const ::mlir::Operation,
+                                          CastInfo<T, ::mlir::Operation>> {};
 } // namespace llvm
 
 #endif // MLIR_IR_OPERATION_H