Add support for Operation interfaces.
authorRiver Riddle <riverriddle@google.com>
Mon, 19 Aug 2019 19:43:46 +0000 (12:43 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 19 Aug 2019 19:44:14 +0000 (12:44 -0700)
Operation interfaces, as the name suggests, are those registered at the
Operation level. These interfaces provide an opaque view into derived
operations, by providing a virtual interface that must be implemented. As an
example, the Linalg dialect implements an interface LinalgOp that provides
general queries about some of the dialects library operations. These queries may
provide things like: the number of parallel loops, the number of inputs and
outputs, etc.

Operation interfaces are defined by overriding the CRTP base class OpInterface.
This class takes as a template parameter, a `Traits` class that defines a
Concept and a Model class. These classes provide an implementation of
concept-based polymorphism, where the Concept defines a set of virtual methods
that are overridden by the Model that is templated on the concrete operation
type. It is important to note that these classes should be pure in that they
contain no non-static data members.

PiperOrigin-RevId: 264218741

mlir/g3doc/Interfaces.md
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td
mlir/include/mlir/Linalg/IR/LinalgOps.h

index b1f823a..666654a 100644 (file)
@@ -101,3 +101,73 @@ InlinerInterface interface(ctx);
 if(!interface.isLegalToInline(...))
    ...
 ```
+
+### Operation Interfaces
+
+Operation interfaces, as the name suggests, are those registered at the
+Operation level. These interfaces provide an opaque view into derived
+operations, by providing a virtual interface that must be implemented. As an
+example, the `Linalg` dialect may implement an interface that provides general
+queries about some of the dialects library operations. These queries may provide
+things like: the number of parallel loops, the number of inputs and outputs,
+etc.
+
+Operation interfaces are defined by overriding the CRTP base class
+`OpInterface`. This class takes as a template parameter, a `Traits` class that
+defines a `Concept` and a `Model` class. These classes provide an implementation
+of concept-based polymorphism, where the Concept defines a set of virtual
+methods that are overridden by the Model that is templated on the concrete
+operation type. It is important to note that these classes should be pure in
+that they contain no non-static data members. Operations that wish to override
+this interface should add the provided trait `OpInterface<..>::Trait` upon
+registration.
+
+```c++
+struct ExampleOpInterfaceTraits {
+/// Define a base concept class that defines the virtual interface that needs
+/// to be overridden.
+struct Concept {
+  virtual ~Concept();
+  virtual unsigned getNumInputs(Operation *op) = 0;
+};
+
+/// Define a model class that specializes a concept on a given operation type.
+template <typename OpT>
+struct Model {
+  /// Override the method to dispatch on the concrete operation.
+  unsigned getNumInputs(Operation *op) final {
+    return llvm::cast<OpT>(op).getNumInputs();
+  }
+};
+};
+
+class ExampleOpInterface : public OpInterface<ExampleOpInterface,
+                                              ExampleOpInterfaceTraits> {
+public:
+  /// The interface dispatches to 'getImpl()', an instance of the concept.
+  unsigned getNumInputs() {
+    return getImpl()->getNumInputs(getOperation());
+  }
+};
+
+```
+
+Once the interface has been defined, it is registered to an operation by adding
+the provided trait `ExampleOpInterface::Trait`. Using this interface is just
+like using any other derived operation type, i.e. casting:
+
+```c++
+/// When defining the operation, the interface is registered via the nested
+/// 'Trait' class provided by the 'OpInterface<>' base class.
+class MyOp : public Op<MyOp, ExampleOpInterface::Trait> {
+public:
+  /// The definition of the interface method on the derived operation.
+  unsigned getNumInputs() { return ...; }
+};
+
+/// Later, we can query if a specific operation(like 'MyOp') overrides the given
+/// interface.
+Operation *op = ...;
+if (ExampleOpInterface example = dyn_cast<ExampleOpInterface>(op))
+ llvm::errs() << "num inputs = " << example.getNumInputs() << "\n";
+```
index eb49c23..f134979 100644 (file)
@@ -1070,6 +1070,19 @@ def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">;
 def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;
 
 //===----------------------------------------------------------------------===//
+// OpInterface definitions
+//===----------------------------------------------------------------------===//
+
+// NativeOpInterface corresponds to a specific 'OpInterface' class defined in
+// C++. The purpose to wrap around C++ symbol string with this class is to make
+// interfaces specified for ops in TableGen less alien and more integrated.
+class NativeOpInterface<string prop> : NativeOpTrait<""> {
+  // TODO(riverriddle) Remove when operation interfaces have their own trait
+  // subclass.
+  let trait = prop # "::Trait";
+}
+
+//===----------------------------------------------------------------------===//
 // Op definitions
 //===----------------------------------------------------------------------===//
 
index ed68936..fd35262 100644 (file)
@@ -996,10 +996,121 @@ private:
                               traitID);
   }
 
-  /// Allow access to 'hasTrait'.
+  /// Returns an opaque pointer to a concept instance of the interface with the
+  /// given ID if one was registered to this operation.
+  static void *getRawInterface(ClassID *id) {
+    return InterfaceLookup::template lookup<Traits<ConcreteType>...>(id);
+  }
+
+  struct InterfaceLookup {
+    /// Trait to check if T provides a static 'getInterfaceID' method.
+    template <typename T, typename... Args>
+    using has_get_interface_id = decltype(T::getInterfaceID());
+
+    /// If 'T' is the same interface as 'interfaceID' return the concept
+    /// instance.
+    template <typename T>
+    static typename std::enable_if<is_detected<has_get_interface_id, T>::value,
+                                   void *>::type
+    lookup(ClassID *interfaceID) {
+      return (T::getInterfaceID() == interfaceID) ? &T::instance() : nullptr;
+    }
+
+    /// 'T' is known to not be an interface, return nullptr.
+    template <typename T>
+    static typename std::enable_if<!is_detected<has_get_interface_id, T>::value,
+                                   void *>::type
+    lookup(ClassID *) {
+      return nullptr;
+    }
+
+    template <typename T, typename T2, typename... Ts>
+    static void *lookup(ClassID *interfaceID) {
+      auto *concept = lookup<T>(interfaceID);
+      return concept ? concept : lookup<T2, Ts...>(interfaceID);
+    }
+  };
+
+  /// Allow access to 'hasTrait' and 'getRawInterface'.
   friend AbstractOperation;
 };
 
+/// This class represents the base of an operation interface. Operation
+/// interfaces provide access to derived *Op properties through an opaquely
+/// Operation instance. Derived interfaces must also provide a 'Traits' class
+/// that defines a 'Concept' and a 'Model' class. The 'Concept' class defines an
+/// abstract virtual interface, where as the 'Model' class implements this
+/// interface for a specific derived *Op type. Both of these classes *must* not
+/// contain non-static data. A simple example is shown below:
+///
+///  struct ExampleOpInterfaceTraits {
+///    struct Concept {
+///      virtual unsigned getNumInputs(Operation *op) = 0;
+///    };
+///    template <typename OpT> class Model {
+///      unsigned getNumInputs(Operation *op) final {
+///        return llvm::cast<OpT>(op).getNumInputs();
+///      }
+///    };
+///  };
+///
+template <typename ConcreteType, typename Traits>
+class OpInterface : public Op<ConcreteType> {
+public:
+  using Concept = typename Traits::Concept;
+  template <typename T> using Model = typename Traits::template Model<T>;
+
+  OpInterface(Operation *op = nullptr)
+      : Op<ConcreteType>(op), impl(op ? getInterfaceFor(op) : nullptr) {
+    assert((!op || impl) &&
+           "instantiating an interface with an unregistered operation");
+  }
+
+  /// Support 'classof' by checking if the given operation defines the concrete
+  /// interface.
+  static bool classof(Operation *op) { return getInterfaceFor(op); }
+
+  /// Define an accessor for the ID of this interface.
+  static ClassID *getInterfaceID() { return ClassID::getID<ConcreteType>(); }
+
+  /// This is a special trait that registers a given interface with an
+  /// operation.
+  template <typename ConcreteOp>
+  struct Trait : public OpTrait::TraitBase<ConcreteOp, Trait> {
+    /// Define an accessor for the ID of this interface.
+    static ClassID *getInterfaceID() { return ClassID::getID<ConcreteType>(); }
+
+    /// Provide an accessor to a static instance of the interface model for the
+    /// concrete operation type.
+    /// The implementation is inspired from Sean Parent's concept-based
+    /// polymorphism. A key difference is that the set of classes erased is
+    /// statically known, which alleviates the need for using dynamic memory
+    /// allocation.
+    /// We use a zero-sized templated class `Model<ConcreteOp>` to emit the
+    /// virtual table and generate a singleton object for each instantiation of
+    /// this class.
+    static Concept &instance() {
+      static Model<ConcreteOp> singleton;
+      return singleton;
+    }
+  };
+
+protected:
+  /// Get the raw concept in the correct derived concept type.
+  Concept *getImpl() { return impl; }
+
+private:
+  /// Returns the impl interface instance for the given operation.
+  static Concept *getInterfaceFor(Operation *op) {
+    // Access the raw interface from the abstract operation.
+    auto *abstractOp = op->getAbstractOperation();
+    return abstractOp ? abstractOp->getInterface<ConcreteType>() : nullptr;
+  }
+
+  /// A pointer to the impl concept object.
+  Concept *impl;
+};
+
 // These functions are out-of-line implementations of the methods in BinaryOp,
 // which avoids them being template instantiated/duplicated.
 namespace impl {
index 204da29..4871c85 100644 (file)
@@ -140,6 +140,14 @@ public:
     return opProperties & static_cast<OperationProperties>(property);
   }
 
+  /// Returns an instance of the concept object for the given interface if it
+  /// was registered to this operation, null otherwise. This should not be used
+  /// directly.
+  template <typename T> typename T::Concept *getInterface() const {
+    return reinterpret_cast<typename T::Concept *>(
+        getRawInterface(T::getInterfaceID()));
+  }
+
   /// Returns if the operation has a particular trait.
   template <template <typename T> class Trait> bool hasTrait() const {
     return hasRawTrait(ClassID::getID<Trait>());
@@ -156,7 +164,7 @@ public:
     return AbstractOperation(
         T::getOperationName(), dialect, T::getOperationProperties(), T::classof,
         T::parseAssembly, T::printAssembly, T::verifyInvariants, T::foldHook,
-        T::getCanonicalizationPatterns, T::hasTrait);
+        T::getCanonicalizationPatterns, T::getRawInterface, T::hasTrait);
   }
 
 private:
@@ -170,16 +178,23 @@ private:
                                 SmallVectorImpl<OpFoldResult> &results),
       void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
                                           MLIRContext *context),
+      void *(&getRawInterface)(ClassID *interfaceID),
       bool (&hasTrait)(ClassID *traitID))
       : name(name), dialect(dialect), classof(classof),
         parseAssembly(parseAssembly), printAssembly(printAssembly),
         verifyInvariants(verifyInvariants), foldHook(foldHook),
         getCanonicalizationPatterns(getCanonicalizationPatterns),
-        opProperties(opProperties), hasRawTrait(hasTrait) {}
+        opProperties(opProperties), getRawInterface(getRawInterface),
+        hasRawTrait(hasTrait) {}
 
   /// The properties of the operation.
   const OperationProperties opProperties;
 
+  /// Returns a raw instance of the concept for the given interface id if it is
+  /// registered to this operation, nullptr otherwise. This should not be used
+  /// directly.
+  void *(&getRawInterface)(ClassID *interfaceID);
+
   /// This hook returns if the operation contains the trait corresponding
   /// to the given ClassID.
   bool (&hasRawTrait)(ClassID *traitID);
index 998d68b..d807b9f 100644 (file)
@@ -68,12 +68,17 @@ LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
 
 def ViewTraits : NativeOpTrait<"linalg::ViewTraits">;
 
+// The linalg 'LinalgLibraryInterface' provides access to the 'LinalgOp'
+// interface.
+def LinalgLibraryInterface : NativeOpInterface<"LinalgOp">;
+
 // Base Tablegen class for Linalg ops.
 // Linalg ops that correspond to library calls operate on linalg::View as their
 // first operands. These may be optionally followed by non-view operands
 // depending on the specific Linalg op.
 class LinalgLibraryBase_Op<string mnemonic, list<OpTrait> props>
-  : Op<Linalg_Dialect, mnemonic, !listconcat(props, [ViewTraits])> {
+  : Op<Linalg_Dialect, mnemonic,
+       !listconcat(props, [ViewTraits, LinalgLibraryInterface])> {
   let parser = [{ return parseLinalgLibraryOp(parser, result); }];
   let printer = [{ printLinalgLibraryOp(p, *this); }];
 }
index 140b9bc..761fed8 100644 (file)
@@ -55,12 +55,6 @@ namespace linalg {
 ///   name mangles into `linalg_matmul_viewxxf32_viewxxf32_viewxxf32_impl`
 std::string generateLibraryCallName(Operation *op);
 
-#define GET_OP_CLASSES
-#include "mlir/Linalg/IR/LinalgOps.h.inc"
-
-#define GET_OP_CLASSES
-#include "mlir/Linalg/IR/LinalgLibraryOps.h.inc"
-
 /// Returns the list of maps that map loops to operands of a Linalg op.
 /// The i-th affine map identifies loop indices to subscripts that are used when
 /// accessing the i-th operand.
@@ -79,78 +73,8 @@ std::string generateLibraryCallName(Operation *op);
 /// Only permutation maps are currently supported.
 SmallVector<AffineMap, 4> loopToOperandRangesMaps(Operation *op);
 
-llvm::raw_ostream &operator<<(llvm::raw_ostream &os, SubViewOp::Range &range);
-
-/// A LinalgOp behaves like a base class for the Linalg operations that are
-/// defined in LinalgLibraryOps.td. The implementation does not use inheritance
-/// directly. Instead, a LinalgOp directly derives from Op, hides the `classof`
-/// method and dispatches to the appropriate LinalgLibraryOp.
-/// This allows writing generic passes, like tiling, for all current and future
-/// LinalgOps without requiring templating and dispatch in multiple places.
-class LinalgOp : public Op<LinalgOp> {
-public:
-  using Op::Op;
-
-  LinalgOp(Operation *op) : Op<LinalgOp>(op) {
-    impl = ModelDispatch<
-#define GET_OP_LIST
-#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
-        >::dispatch(op);
-  }
-
-  static bool classof(Operation *op) {
-    return ModelDispatch<
-#define GET_OP_LIST
-#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
-        >::classof(op);
-  }
-
-  unsigned getNumParallelLoops() {
-    return impl->getNumParallelLoops(getOperation());
-  }
-  unsigned getNumReductionLoops() {
-    return impl->getNumReductionLoops(getOperation());
-  }
-  unsigned getNumWindowLoops() {
-    return impl->getNumWindowLoops(getOperation());
-  }
-  unsigned getNumLoops() {
-    return getNumParallelLoops() + getNumReductionLoops() + getNumWindowLoops();
-  }
-  unsigned getNumInputs() { return impl->getNumInputs(getOperation()); }
-  unsigned getNumOutputs() { return impl->getNumOutputs(getOperation()); }
-  unsigned getNumInputsAndOutputs() {
-    return impl->getNumInputsAndOutputs(getOperation());
-  }
-  Value *getInput(unsigned i) { return impl->getInput(getOperation(), i); }
-  llvm::Optional<unsigned> getIndexOfInput(Value *view) {
-    return impl->getIndexOfInput(getOperation(), view);
-  }
-  ViewType getInputViewType(unsigned i) {
-    return impl->getInputViewType(getOperation(), i);
-  }
-  Operation::operand_range getInputs() {
-    return impl->getInputs(getOperation());
-  }
-  Value *getOutput(unsigned i) { return impl->getOutput(getOperation(), i); }
-  llvm::Optional<unsigned> getIndexOfOutput(Value *view) {
-    return impl->getIndexOfOutput(getOperation(), view);
-  }
-  ViewType getOutputViewType(unsigned i) {
-    return impl->getOutputViewType(getOperation(), i);
-  }
-  Operation::operand_range getOutputs() {
-    return impl->getOutputs(getOperation());
-  }
-  Operation::operand_range getInputsAndOutputs() {
-    return impl->getInputsAndOutputs(getOperation());
-  }
-  LinalgOp create(OpBuilder &builder, Location loc, ArrayRef<Value *> operands,
-                  ArrayRef<NamedAttribute> attributes) {
-    return LinalgOp(impl->create(builder, loc, operands, attributes));
-  }
-
-private:
+namespace detail {
+struct LinalgOpInterfaceTraits {
   struct Concept {
     virtual ~Concept() = default;
     virtual unsigned getNumInputs(Operation *op) = 0;
@@ -175,20 +99,7 @@ private:
                               ArrayRef<NamedAttribute> attributes) = 0;
   };
 
-  /// The implementation is inspired from Sean Parent's concept-based
-  /// polymorphism. A key difference is that the set of classes erased is
-  /// statically known, which alleviates the need for using dynamic memory
-  /// allocation.
-  /// We use a zero-sized templated class `Model<ConcreteOp>` to emit the
-  /// virtual table and generate a singleton object for each instantiation of
-  /// this class.
-  /// We pay the cost of initialization once on construction (find which class
-  /// to dispatch to) and then a virtual dispatch on every call.
   template <typename ConcreteOp> struct Model : public Concept {
-    static Model<ConcreteOp> &instance() {
-      static Model<ConcreteOp> singleton;
-      return singleton;
-    }
     unsigned getNumInputs(Operation *op) override {
       return cast<ConcreteOp>(op).getNumInputs();
     }
@@ -243,29 +154,75 @@ private:
                                         attributes);
     }
   };
-  Concept *impl;
-
-  template <typename... Types> struct ModelDispatch;
+};
+} // namespace detail
 
-  template <typename First, typename... Rest>
-  struct ModelDispatch<First, Rest...> {
-    static bool classof(Operation *op) {
-      return isa<First>(op) || ModelDispatch<Rest...>::classof(op);
-    }
-    static Concept *dispatch(Operation *op) {
-      return isa<First>(op) ? &Model<First>::instance()
-                            : ModelDispatch<Rest...>::dispatch(op);
-    }
-  };
+/// A LinalgOp behaves like a base class for the Linalg operations that are
+/// defined in LinalgLibraryOps.td. The implementation does not use inheritance
+/// directly. Instead, a LinalgOp directly derives from Op, hides the `classof`
+/// method and dispatches to the appropriate LinalgLibraryOp.
+/// This allows writing generic passes, like tiling, for all current and future
+/// LinalgOps without requiring templating and dispatch in multiple places.
+class LinalgOp : public OpInterface<LinalgOp, detail::LinalgOpInterfaceTraits> {
+public:
+  using OpInterface<LinalgOp, detail::LinalgOpInterfaceTraits>::OpInterface;
 
-  template <typename...> struct ModelDispatch {
-    static bool classof(Operation *op) { return false; }
-    static Concept *dispatch(Operation *op) {
-      llvm_unreachable("Invalid LinalgOp");
-    }
-  };
+  unsigned getNumParallelLoops() {
+    return getImpl()->getNumParallelLoops(getOperation());
+  }
+  unsigned getNumReductionLoops() {
+    return getImpl()->getNumReductionLoops(getOperation());
+  }
+  unsigned getNumWindowLoops() {
+    return getImpl()->getNumWindowLoops(getOperation());
+  }
+  unsigned getNumLoops() {
+    return getNumParallelLoops() + getNumReductionLoops() + getNumWindowLoops();
+  }
+  unsigned getNumInputs() { return getImpl()->getNumInputs(getOperation()); }
+  unsigned getNumOutputs() { return getImpl()->getNumOutputs(getOperation()); }
+  unsigned getNumInputsAndOutputs() {
+    return getImpl()->getNumInputsAndOutputs(getOperation());
+  }
+  Value *getInput(unsigned i) { return getImpl()->getInput(getOperation(), i); }
+  llvm::Optional<unsigned> getIndexOfInput(Value *view) {
+    return getImpl()->getIndexOfInput(getOperation(), view);
+  }
+  ViewType getInputViewType(unsigned i) {
+    return getImpl()->getInputViewType(getOperation(), i);
+  }
+  Operation::operand_range getInputs() {
+    return getImpl()->getInputs(getOperation());
+  }
+  Value *getOutput(unsigned i) {
+    return getImpl()->getOutput(getOperation(), i);
+  }
+  llvm::Optional<unsigned> getIndexOfOutput(Value *view) {
+    return getImpl()->getIndexOfOutput(getOperation(), view);
+  }
+  ViewType getOutputViewType(unsigned i) {
+    return getImpl()->getOutputViewType(getOperation(), i);
+  }
+  Operation::operand_range getOutputs() {
+    return getImpl()->getOutputs(getOperation());
+  }
+  Operation::operand_range getInputsAndOutputs() {
+    return getImpl()->getInputsAndOutputs(getOperation());
+  }
+  LinalgOp create(OpBuilder &builder, Location loc, ArrayRef<Value *> operands,
+                  ArrayRef<NamedAttribute> attributes) {
+    return LinalgOp(getImpl()->create(builder, loc, operands, attributes));
+  }
 };
 
+#define GET_OP_CLASSES
+#include "mlir/Linalg/IR/LinalgOps.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Linalg/IR/LinalgLibraryOps.h.inc"
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, SubViewOp::Range &range);
+
 } // namespace linalg
 } // namespace mlir