From 7f61396cfac5f114707a4240a314dec28e03a1d5 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 12 Nov 2020 22:55:28 -0800 Subject: [PATCH] [mlir][Interfaces] Add implicit casts from concrete operation types to the interfaces they implement. This removes the need to have an explicit `cast<>` given that we always know it `isa` instance of the interface. Differential Revision: https://reviews.llvm.org/D91304 --- mlir/include/mlir/Support/InterfaceSupport.h | 32 ++++++++++++++-------- .../Dialect/Linalg/Transforms/FusionOnTensors.cpp | 5 ++-- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h index fa3ef3e..44b0f67 100644 --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -16,7 +16,6 @@ #include "mlir/Support/TypeID.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/TypeName.h" -#include "llvm/Support/raw_ostream.h" namespace mlir { namespace detail { @@ -75,10 +74,28 @@ public: using InterfaceBase = Interface; + /// This is a special trait that registers a given interface with an object. + template + struct Trait : public BaseTrait { + using ModelT = Model; + + /// Define an accessor for the ID of this interface. + static TypeID getInterfaceID() { return TypeID::get(); } + }; + + /// Construct an interface from an instance of the value type. Interface(ValueT t = ValueT()) : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { - assert((!t || impl) && - "instantiating an interface with an unregistered operation"); + assert((!t || impl) && "expected value to provide interface instance"); + } + + /// Construct an interface instance from a type that implements this + /// interface's trait. + template , T>::value> * = nullptr> + Interface(T t) + : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { + assert((!t || impl) && "expected value to provide interface instance"); } /// Support 'classof' by checking if the given object defines the concrete @@ -88,15 +105,6 @@ public: /// Define an accessor for the ID of this interface. static TypeID getInterfaceID() { return TypeID::get(); } - /// This is a special trait that registers a given interface with an object. - template - struct Trait : public BaseTrait { - using ModelT = Model; - - /// Define an accessor for the ID of this interface. - static TypeID getInterfaceID() { return TypeID::get(); } - }; - protected: /// Get the raw concept in the correct derived concept type. const Concept *getImpl() const { return impl; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index 7cb9bb5..abc10e8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -341,10 +341,9 @@ template static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter, Args... args) { if (isa(op.getOperation())) - return cast(rewriter.create(args...).getOperation()); + return rewriter.create(args...); if (isa(op.getOperation())) - return cast( - rewriter.create(args...).getOperation()); + return rewriter.create(args...); llvm_unreachable( "expected only linalg.generic or linalg.indexed_generic ops"); return nullptr; -- 2.7.4