if(!interface.isLegalToInline(...))
...
```
+
+### Operation Interfaces
+
+Operation interfaces, as the name suggests, are those registered at the
+Operation level. These interfaces provide an opaque view into derived
+operations, by providing a virtual interface that must be implemented. As an
+example, the `Linalg` dialect may implement an interface that provides general
+queries about some of the dialects library operations. These queries may provide
+things like: the number of parallel loops, the number of inputs and outputs,
+etc.
+
+Operation interfaces are defined by overriding the CRTP base class
+`OpInterface`. This class takes as a template parameter, a `Traits` class that
+defines a `Concept` and a `Model` class. These classes provide an implementation
+of concept-based polymorphism, where the Concept defines a set of virtual
+methods that are overridden by the Model that is templated on the concrete
+operation type. It is important to note that these classes should be pure in
+that they contain no non-static data members. Operations that wish to override
+this interface should add the provided trait `OpInterface<..>::Trait` upon
+registration.
+
+```c++
+struct ExampleOpInterfaceTraits {
+/// Define a base concept class that defines the virtual interface that needs
+/// to be overridden.
+struct Concept {
+ virtual ~Concept();
+ virtual unsigned getNumInputs(Operation *op) = 0;
+};
+
+/// Define a model class that specializes a concept on a given operation type.
+template <typename OpT>
+struct Model {
+ /// Override the method to dispatch on the concrete operation.
+ unsigned getNumInputs(Operation *op) final {
+ return llvm::cast<OpT>(op).getNumInputs();
+ }
+};
+};
+
+class ExampleOpInterface : public OpInterface<ExampleOpInterface,
+ ExampleOpInterfaceTraits> {
+public:
+ /// The interface dispatches to 'getImpl()', an instance of the concept.
+ unsigned getNumInputs() {
+ return getImpl()->getNumInputs(getOperation());
+ }
+};
+
+```
+
+Once the interface has been defined, it is registered to an operation by adding
+the provided trait `ExampleOpInterface::Trait`. Using this interface is just
+like using any other derived operation type, i.e. casting:
+
+```c++
+/// When defining the operation, the interface is registered via the nested
+/// 'Trait' class provided by the 'OpInterface<>' base class.
+class MyOp : public Op<MyOp, ExampleOpInterface::Trait> {
+public:
+ /// The definition of the interface method on the derived operation.
+ unsigned getNumInputs() { return ...; }
+};
+
+/// Later, we can query if a specific operation(like 'MyOp') overrides the given
+/// interface.
+Operation *op = ...;
+if (ExampleOpInterface example = dyn_cast<ExampleOpInterface>(op))
+ llvm::errs() << "num inputs = " << example.getNumInputs() << "\n";
+```
def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;
//===----------------------------------------------------------------------===//
+// OpInterface definitions
+//===----------------------------------------------------------------------===//
+
+// NativeOpInterface corresponds to a specific 'OpInterface' class defined in
+// C++. The purpose to wrap around C++ symbol string with this class is to make
+// interfaces specified for ops in TableGen less alien and more integrated.
+class NativeOpInterface<string prop> : NativeOpTrait<""> {
+ // TODO(riverriddle) Remove when operation interfaces have their own trait
+ // subclass.
+ let trait = prop # "::Trait";
+}
+
+//===----------------------------------------------------------------------===//
// Op definitions
//===----------------------------------------------------------------------===//
traitID);
}
- /// Allow access to 'hasTrait'.
+ /// Returns an opaque pointer to a concept instance of the interface with the
+ /// given ID if one was registered to this operation.
+ static void *getRawInterface(ClassID *id) {
+ return InterfaceLookup::template lookup<Traits<ConcreteType>...>(id);
+ }
+
+ struct InterfaceLookup {
+ /// Trait to check if T provides a static 'getInterfaceID' method.
+ template <typename T, typename... Args>
+ using has_get_interface_id = decltype(T::getInterfaceID());
+
+ /// If 'T' is the same interface as 'interfaceID' return the concept
+ /// instance.
+ template <typename T>
+ static typename std::enable_if<is_detected<has_get_interface_id, T>::value,
+ void *>::type
+ lookup(ClassID *interfaceID) {
+ return (T::getInterfaceID() == interfaceID) ? &T::instance() : nullptr;
+ }
+
+ /// 'T' is known to not be an interface, return nullptr.
+ template <typename T>
+ static typename std::enable_if<!is_detected<has_get_interface_id, T>::value,
+ void *>::type
+ lookup(ClassID *) {
+ return nullptr;
+ }
+
+ template <typename T, typename T2, typename... Ts>
+ static void *lookup(ClassID *interfaceID) {
+ auto *concept = lookup<T>(interfaceID);
+ return concept ? concept : lookup<T2, Ts...>(interfaceID);
+ }
+ };
+
+ /// Allow access to 'hasTrait' and 'getRawInterface'.
friend AbstractOperation;
};
+/// This class represents the base of an operation interface. Operation
+/// interfaces provide access to derived *Op properties through an opaquely
+/// Operation instance. Derived interfaces must also provide a 'Traits' class
+/// that defines a 'Concept' and a 'Model' class. The 'Concept' class defines an
+/// abstract virtual interface, where as the 'Model' class implements this
+/// interface for a specific derived *Op type. Both of these classes *must* not
+/// contain non-static data. A simple example is shown below:
+///
+/// struct ExampleOpInterfaceTraits {
+/// struct Concept {
+/// virtual unsigned getNumInputs(Operation *op) = 0;
+/// };
+/// template <typename OpT> class Model {
+/// unsigned getNumInputs(Operation *op) final {
+/// return llvm::cast<OpT>(op).getNumInputs();
+/// }
+/// };
+/// };
+///
+template <typename ConcreteType, typename Traits>
+class OpInterface : public Op<ConcreteType> {
+public:
+ using Concept = typename Traits::Concept;
+ template <typename T> using Model = typename Traits::template Model<T>;
+
+ OpInterface(Operation *op = nullptr)
+ : Op<ConcreteType>(op), impl(op ? getInterfaceFor(op) : nullptr) {
+ assert((!op || impl) &&
+ "instantiating an interface with an unregistered operation");
+ }
+
+ /// Support 'classof' by checking if the given operation defines the concrete
+ /// interface.
+ static bool classof(Operation *op) { return getInterfaceFor(op); }
+
+ /// Define an accessor for the ID of this interface.
+ static ClassID *getInterfaceID() { return ClassID::getID<ConcreteType>(); }
+
+ /// This is a special trait that registers a given interface with an
+ /// operation.
+ template <typename ConcreteOp>
+ struct Trait : public OpTrait::TraitBase<ConcreteOp, Trait> {
+ /// Define an accessor for the ID of this interface.
+ static ClassID *getInterfaceID() { return ClassID::getID<ConcreteType>(); }
+
+ /// Provide an accessor to a static instance of the interface model for the
+ /// concrete operation type.
+ /// The implementation is inspired from Sean Parent's concept-based
+ /// polymorphism. A key difference is that the set of classes erased is
+ /// statically known, which alleviates the need for using dynamic memory
+ /// allocation.
+ /// We use a zero-sized templated class `Model<ConcreteOp>` to emit the
+ /// virtual table and generate a singleton object for each instantiation of
+ /// this class.
+ static Concept &instance() {
+ static Model<ConcreteOp> singleton;
+ return singleton;
+ }
+ };
+
+protected:
+ /// Get the raw concept in the correct derived concept type.
+ Concept *getImpl() { return impl; }
+
+private:
+ /// Returns the impl interface instance for the given operation.
+ static Concept *getInterfaceFor(Operation *op) {
+ // Access the raw interface from the abstract operation.
+ auto *abstractOp = op->getAbstractOperation();
+ return abstractOp ? abstractOp->getInterface<ConcreteType>() : nullptr;
+ }
+
+ /// A pointer to the impl concept object.
+ Concept *impl;
+};
+
// These functions are out-of-line implementations of the methods in BinaryOp,
// which avoids them being template instantiated/duplicated.
namespace impl {
return opProperties & static_cast<OperationProperties>(property);
}
+ /// Returns an instance of the concept object for the given interface if it
+ /// was registered to this operation, null otherwise. This should not be used
+ /// directly.
+ template <typename T> typename T::Concept *getInterface() const {
+ return reinterpret_cast<typename T::Concept *>(
+ getRawInterface(T::getInterfaceID()));
+ }
+
/// Returns if the operation has a particular trait.
template <template <typename T> class Trait> bool hasTrait() const {
return hasRawTrait(ClassID::getID<Trait>());
return AbstractOperation(
T::getOperationName(), dialect, T::getOperationProperties(), T::classof,
T::parseAssembly, T::printAssembly, T::verifyInvariants, T::foldHook,
- T::getCanonicalizationPatterns, T::hasTrait);
+ T::getCanonicalizationPatterns, T::getRawInterface, T::hasTrait);
}
private:
SmallVectorImpl<OpFoldResult> &results),
void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
MLIRContext *context),
+ void *(&getRawInterface)(ClassID *interfaceID),
bool (&hasTrait)(ClassID *traitID))
: name(name), dialect(dialect), classof(classof),
parseAssembly(parseAssembly), printAssembly(printAssembly),
verifyInvariants(verifyInvariants), foldHook(foldHook),
getCanonicalizationPatterns(getCanonicalizationPatterns),
- opProperties(opProperties), hasRawTrait(hasTrait) {}
+ opProperties(opProperties), getRawInterface(getRawInterface),
+ hasRawTrait(hasTrait) {}
/// The properties of the operation.
const OperationProperties opProperties;
+ /// Returns a raw instance of the concept for the given interface id if it is
+ /// registered to this operation, nullptr otherwise. This should not be used
+ /// directly.
+ void *(&getRawInterface)(ClassID *interfaceID);
+
/// This hook returns if the operation contains the trait corresponding
/// to the given ClassID.
bool (&hasRawTrait)(ClassID *traitID);
def ViewTraits : NativeOpTrait<"linalg::ViewTraits">;
+// The linalg 'LinalgLibraryInterface' provides access to the 'LinalgOp'
+// interface.
+def LinalgLibraryInterface : NativeOpInterface<"LinalgOp">;
+
// Base Tablegen class for Linalg ops.
// Linalg ops that correspond to library calls operate on linalg::View as their
// first operands. These may be optionally followed by non-view operands
// depending on the specific Linalg op.
class LinalgLibraryBase_Op<string mnemonic, list<OpTrait> props>
- : Op<Linalg_Dialect, mnemonic, !listconcat(props, [ViewTraits])> {
+ : Op<Linalg_Dialect, mnemonic,
+ !listconcat(props, [ViewTraits, LinalgLibraryInterface])> {
let parser = [{ return parseLinalgLibraryOp(parser, result); }];
let printer = [{ printLinalgLibraryOp(p, *this); }];
}
/// name mangles into `linalg_matmul_viewxxf32_viewxxf32_viewxxf32_impl`
std::string generateLibraryCallName(Operation *op);
-#define GET_OP_CLASSES
-#include "mlir/Linalg/IR/LinalgOps.h.inc"
-
-#define GET_OP_CLASSES
-#include "mlir/Linalg/IR/LinalgLibraryOps.h.inc"
-
/// Returns the list of maps that map loops to operands of a Linalg op.
/// The i-th affine map identifies loop indices to subscripts that are used when
/// accessing the i-th operand.
/// Only permutation maps are currently supported.
SmallVector<AffineMap, 4> loopToOperandRangesMaps(Operation *op);
-llvm::raw_ostream &operator<<(llvm::raw_ostream &os, SubViewOp::Range &range);
-
-/// A LinalgOp behaves like a base class for the Linalg operations that are
-/// defined in LinalgLibraryOps.td. The implementation does not use inheritance
-/// directly. Instead, a LinalgOp directly derives from Op, hides the `classof`
-/// method and dispatches to the appropriate LinalgLibraryOp.
-/// This allows writing generic passes, like tiling, for all current and future
-/// LinalgOps without requiring templating and dispatch in multiple places.
-class LinalgOp : public Op<LinalgOp> {
-public:
- using Op::Op;
-
- LinalgOp(Operation *op) : Op<LinalgOp>(op) {
- impl = ModelDispatch<
-#define GET_OP_LIST
-#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
- >::dispatch(op);
- }
-
- static bool classof(Operation *op) {
- return ModelDispatch<
-#define GET_OP_LIST
-#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
- >::classof(op);
- }
-
- unsigned getNumParallelLoops() {
- return impl->getNumParallelLoops(getOperation());
- }
- unsigned getNumReductionLoops() {
- return impl->getNumReductionLoops(getOperation());
- }
- unsigned getNumWindowLoops() {
- return impl->getNumWindowLoops(getOperation());
- }
- unsigned getNumLoops() {
- return getNumParallelLoops() + getNumReductionLoops() + getNumWindowLoops();
- }
- unsigned getNumInputs() { return impl->getNumInputs(getOperation()); }
- unsigned getNumOutputs() { return impl->getNumOutputs(getOperation()); }
- unsigned getNumInputsAndOutputs() {
- return impl->getNumInputsAndOutputs(getOperation());
- }
- Value *getInput(unsigned i) { return impl->getInput(getOperation(), i); }
- llvm::Optional<unsigned> getIndexOfInput(Value *view) {
- return impl->getIndexOfInput(getOperation(), view);
- }
- ViewType getInputViewType(unsigned i) {
- return impl->getInputViewType(getOperation(), i);
- }
- Operation::operand_range getInputs() {
- return impl->getInputs(getOperation());
- }
- Value *getOutput(unsigned i) { return impl->getOutput(getOperation(), i); }
- llvm::Optional<unsigned> getIndexOfOutput(Value *view) {
- return impl->getIndexOfOutput(getOperation(), view);
- }
- ViewType getOutputViewType(unsigned i) {
- return impl->getOutputViewType(getOperation(), i);
- }
- Operation::operand_range getOutputs() {
- return impl->getOutputs(getOperation());
- }
- Operation::operand_range getInputsAndOutputs() {
- return impl->getInputsAndOutputs(getOperation());
- }
- LinalgOp create(OpBuilder &builder, Location loc, ArrayRef<Value *> operands,
- ArrayRef<NamedAttribute> attributes) {
- return LinalgOp(impl->create(builder, loc, operands, attributes));
- }
-
-private:
+namespace detail {
+struct LinalgOpInterfaceTraits {
struct Concept {
virtual ~Concept() = default;
virtual unsigned getNumInputs(Operation *op) = 0;
ArrayRef<NamedAttribute> attributes) = 0;
};
- /// The implementation is inspired from Sean Parent's concept-based
- /// polymorphism. A key difference is that the set of classes erased is
- /// statically known, which alleviates the need for using dynamic memory
- /// allocation.
- /// We use a zero-sized templated class `Model<ConcreteOp>` to emit the
- /// virtual table and generate a singleton object for each instantiation of
- /// this class.
- /// We pay the cost of initialization once on construction (find which class
- /// to dispatch to) and then a virtual dispatch on every call.
template <typename ConcreteOp> struct Model : public Concept {
- static Model<ConcreteOp> &instance() {
- static Model<ConcreteOp> singleton;
- return singleton;
- }
unsigned getNumInputs(Operation *op) override {
return cast<ConcreteOp>(op).getNumInputs();
}
attributes);
}
};
- Concept *impl;
-
- template <typename... Types> struct ModelDispatch;
+};
+} // namespace detail
- template <typename First, typename... Rest>
- struct ModelDispatch<First, Rest...> {
- static bool classof(Operation *op) {
- return isa<First>(op) || ModelDispatch<Rest...>::classof(op);
- }
- static Concept *dispatch(Operation *op) {
- return isa<First>(op) ? &Model<First>::instance()
- : ModelDispatch<Rest...>::dispatch(op);
- }
- };
+/// A LinalgOp behaves like a base class for the Linalg operations that are
+/// defined in LinalgLibraryOps.td. The implementation does not use inheritance
+/// directly. Instead, a LinalgOp directly derives from Op, hides the `classof`
+/// method and dispatches to the appropriate LinalgLibraryOp.
+/// This allows writing generic passes, like tiling, for all current and future
+/// LinalgOps without requiring templating and dispatch in multiple places.
+class LinalgOp : public OpInterface<LinalgOp, detail::LinalgOpInterfaceTraits> {
+public:
+ using OpInterface<LinalgOp, detail::LinalgOpInterfaceTraits>::OpInterface;
- template <typename...> struct ModelDispatch {
- static bool classof(Operation *op) { return false; }
- static Concept *dispatch(Operation *op) {
- llvm_unreachable("Invalid LinalgOp");
- }
- };
+ unsigned getNumParallelLoops() {
+ return getImpl()->getNumParallelLoops(getOperation());
+ }
+ unsigned getNumReductionLoops() {
+ return getImpl()->getNumReductionLoops(getOperation());
+ }
+ unsigned getNumWindowLoops() {
+ return getImpl()->getNumWindowLoops(getOperation());
+ }
+ unsigned getNumLoops() {
+ return getNumParallelLoops() + getNumReductionLoops() + getNumWindowLoops();
+ }
+ unsigned getNumInputs() { return getImpl()->getNumInputs(getOperation()); }
+ unsigned getNumOutputs() { return getImpl()->getNumOutputs(getOperation()); }
+ unsigned getNumInputsAndOutputs() {
+ return getImpl()->getNumInputsAndOutputs(getOperation());
+ }
+ Value *getInput(unsigned i) { return getImpl()->getInput(getOperation(), i); }
+ llvm::Optional<unsigned> getIndexOfInput(Value *view) {
+ return getImpl()->getIndexOfInput(getOperation(), view);
+ }
+ ViewType getInputViewType(unsigned i) {
+ return getImpl()->getInputViewType(getOperation(), i);
+ }
+ Operation::operand_range getInputs() {
+ return getImpl()->getInputs(getOperation());
+ }
+ Value *getOutput(unsigned i) {
+ return getImpl()->getOutput(getOperation(), i);
+ }
+ llvm::Optional<unsigned> getIndexOfOutput(Value *view) {
+ return getImpl()->getIndexOfOutput(getOperation(), view);
+ }
+ ViewType getOutputViewType(unsigned i) {
+ return getImpl()->getOutputViewType(getOperation(), i);
+ }
+ Operation::operand_range getOutputs() {
+ return getImpl()->getOutputs(getOperation());
+ }
+ Operation::operand_range getInputsAndOutputs() {
+ return getImpl()->getInputsAndOutputs(getOperation());
+ }
+ LinalgOp create(OpBuilder &builder, Location loc, ArrayRef<Value *> operands,
+ ArrayRef<NamedAttribute> attributes) {
+ return LinalgOp(getImpl()->create(builder, loc, operands, attributes));
+ }
};
+#define GET_OP_CLASSES
+#include "mlir/Linalg/IR/LinalgOps.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Linalg/IR/LinalgLibraryOps.h.inc"
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, SubViewOp::Range &range);
+
} // namespace linalg
} // namespace mlir