[mlir] Optimize Op definitions and registration to optimize for code size
authorRiver Riddle <riddleriver@gmail.com>
Mon, 2 Nov 2020 22:21:02 +0000 (14:21 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Mon, 2 Nov 2020 22:39:43 +0000 (14:39 -0800)
This revision refactors the base Op/AbstractOperation classes to reduce the amount of generated code size when defining a new operation. The current scheme involves taking the address of functions defined directly on Op and Trait classes. This is problematic because even when these functions are empty/unused we still result in these functions being defined in the main executable. In this revision, we switch to using SFINAE and template type filtering to remove remove functions that are not needed/used. For example, if an operation does not define a custom `print` method we shouldn't define a templated `printAssembly` method for it. The same applies to parsing/folding/verification/etc. This dropped MLIR code size for a large downstream library by ~10%(~1 mb in an opt build).

Differential Revision: https://reviews.llvm.org/D90196

mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Operation.cpp

index a3328c2..cd64d38 100644 (file)
@@ -150,12 +150,9 @@ protected:
   /// This method is used by derived classes to add their operations to the set.
   ///
   template <typename... Args> void addOperations() {
-    (void)std::initializer_list<int>{0, (addOperation<Args>(), 0)...};
-  }
-  template <typename Arg> void addOperation() {
-    addOperation(AbstractOperation::get<Arg>(*this));
+    (void)std::initializer_list<int>{
+        0, (AbstractOperation::insert<Args>(*this), 0)...};
   }
-  void addOperation(AbstractOperation opInfo);
 
   /// Register a set of type classes with this dialect.
   template <typename... Args> void addTypes() {
index a9f706a..f1be187 100644 (file)
@@ -105,6 +105,12 @@ public:
   /// Return the operation that this refers to.
   Operation *getOperation() { return state; }
 
+  /// Return the dialect that this refers to.
+  Dialect *getDialect() { return getOperation()->getDialect(); }
+
+  /// Return the parent Region of this operation.
+  Region *getParentRegion() { return getOperation()->getParentRegion(); }
+
   /// Returns the closest surrounding operation that contains this operation
   /// or nullptr if this is a top-level operation.
   Operation *getParentOp() { return getOperation()->getParentOp(); }
@@ -238,7 +244,7 @@ protected:
   static ParseResult parse(OpAsmParser &parser, OperationState &result);
 
   // The fallback for the printer is to print it the generic assembly form.
-  void print(OpAsmPrinter &p);
+  static void print(Operation *op, OpAsmPrinter &p);
 
   /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
   /// so we can cast it away here.
@@ -246,6 +252,9 @@ protected:
 
 private:
   Operation *state;
+
+  /// Allow access to internal hook implementation methods.
+  friend AbstractOperation;
 };
 
 // Allow comparing operators.
@@ -267,117 +276,6 @@ inline raw_ostream &operator<<(raw_ostream &os, OpState &op) {
   return os;
 }
 
-/// This template defines the foldHook as used by AbstractOperation.
-///
-/// The default implementation uses a general fold method that can be defined on
-/// custom ops which can return multiple results.
-template <typename ConcreteType, bool isSingleResult, typename = void>
-class FoldingHook {
-public:
-  /// This is an implementation detail of the constant folder hook for
-  /// AbstractOperation.
-  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
-                                SmallVectorImpl<OpFoldResult> &results) {
-    auto operationFoldResult = cast<ConcreteType>(op).fold(operands, results);
-    // Failure to fold or in place fold both mean we can continue folding.
-    if (failed(operationFoldResult) || results.empty()) {
-      auto traitFoldResult = ConcreteType::foldTraits(op, operands, results);
-      // Only return the trait fold result if it is a success since
-      // operationFoldResult might have been a success originally.
-      if (succeeded(traitFoldResult))
-        return traitFoldResult;
-    }
-    return operationFoldResult;
-  }
-
-  /// This hook implements a generalized folder for this operation.  Operations
-  /// can implement this to provide simplifications rules that are applied by
-  /// the Builder::createOrFold API and the canonicalization pass.
-  ///
-  /// This is an intentionally limited interface - implementations of this hook
-  /// can only perform the following changes to the operation:
-  ///
-  ///  1. They can leave the operation alone and without changing the IR, and
-  ///     return failure.
-  ///  2. They can mutate the operation in place, without changing anything else
-  ///     in the IR.  In this case, return success.
-  ///  3. They can return a list of existing values that can be used instead of
-  ///     the operation.  In this case, fill in the results list and return
-  ///     success.  The caller will remove the operation and use those results
-  ///     instead.
-  ///
-  /// This allows expression of some simple in-place canonicalizations (e.g.
-  /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
-  /// generalized constant folding.
-  ///
-  /// If not overridden, this fallback implementation always fails to fold.
-  ///
-  LogicalResult fold(ArrayRef<Attribute> operands,
-                     SmallVectorImpl<OpFoldResult> &results) {
-    return failure();
-  }
-};
-
-/// This template specialization defines the foldHook as used by
-/// AbstractOperation for single-result operations.  This gives the hook a nicer
-/// signature that is easier to implement.
-template <typename ConcreteType, bool isSingleResult>
-class FoldingHook<ConcreteType, isSingleResult,
-                  typename std::enable_if<isSingleResult>::type> {
-public:
-  /// If the operation returns a single value, then the Op can be implicitly
-  /// converted to an Value.  This yields the value of the only result.
-  operator Value() {
-    return static_cast<ConcreteType *>(this)->getOperation()->getResult(0);
-  }
-
-  /// This is an implementation detail of the constant folder hook for
-  /// AbstractOperation.
-  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
-                                SmallVectorImpl<OpFoldResult> &results) {
-    auto result = cast<ConcreteType>(op).fold(operands);
-    // Failure to fold or in place fold both mean we can continue folding.
-    if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
-      // Only consider the trait fold result if it is a success since
-      // the operation fold might have been a success originally.
-      if (auto traitFoldResult = ConcreteType::foldTraits(op, operands))
-        result = traitFoldResult;
-    }
-
-    if (!result)
-      return failure();
-
-    // Check if the operation was folded in place. In this case, the operation
-    // returns itself.
-    if (result.template dyn_cast<Value>() != op->getResult(0))
-      results.push_back(result);
-    return success();
-  }
-
-  /// This hook implements a generalized folder for this operation.  Operations
-  /// can implement this to provide simplifications rules that are applied by
-  /// the Builder::createOrFold API and the canonicalization pass.
-  ///
-  /// This is an intentionally limited interface - implementations of this hook
-  /// can only perform the following changes to the operation:
-  ///
-  ///  1. They can leave the operation alone and without changing the IR, and
-  ///     return nullptr.
-  ///  2. They can mutate the operation in place, without changing anything else
-  ///     in the IR.  In this case, return the operation itself.
-  ///  3. They can return an existing SSA value that can be used instead of
-  ///     the operation.  In this case, return that value.  The caller will
-  ///     remove the operation and use that result instead.
-  ///
-  /// This allows expression of some simple in-place canonicalizations (e.g.
-  /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
-  /// generalized constant folding.
-  ///
-  /// If not overridden, this fallback implementation always fails to fold.
-  ///
-  OpFoldResult fold(ArrayRef<Attribute> operands) { return {}; }
-};
-
 //===----------------------------------------------------------------------===//
 // Operation Trait Types
 //===----------------------------------------------------------------------===//
@@ -441,30 +339,6 @@ protected:
     auto *base = static_cast<OpState *>(concrete);
     return base->getOperation();
   }
-
-  /// Provide default implementations of trait hooks.  This allows traits to
-  /// provide exactly the overrides they care about.
-  static LogicalResult verifyTrait(Operation *op) { return success(); }
-  static AbstractOperation::OperationProperties getTraitProperties() {
-    return 0;
-  }
-
-  static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
-    SmallVector<OpFoldResult, 1> results;
-    if (failed(foldTrait(op, operands, results)))
-      return {};
-    if (results.empty())
-      return op->getResult(0);
-    assert(results.size() == 1 &&
-           "Single result op cannot return multiple fold results");
-
-    return results[0];
-  }
-
-  static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
-                                 SmallVectorImpl<OpFoldResult> &results) {
-    return failure();
-  }
 };
 
 //===----------------------------------------------------------------------===//
@@ -738,6 +612,10 @@ public:
   Value getResult() { return this->getOperation()->getResult(0); }
   Type getType() { return getResult().getType(); }
 
+  /// If the operation returns a single value, then the Op can be implicitly
+  /// converted to an Value. This yields the value of the only result.
+  operator Value() { return getResult(); }
+
   /// Replace all uses of 'this' value with the new value, updating anything in
   /// the IR that uses 'this' to use the other value instead.  When this returns
   /// there are zero uses of 'this'.
@@ -1307,6 +1185,170 @@ struct MemRefsNormalizable
 } // end namespace OpTrait
 
 //===----------------------------------------------------------------------===//
+// Internal Trait Utilities
+//===----------------------------------------------------------------------===//
+
+namespace op_definition_impl {
+//===----------------------------------------------------------------------===//
+// Trait Existence
+
+/// Returns true if this given Trait ID matches the IDs of any of the provided
+/// trait types `Traits`.
+template <template <typename T> class... Traits>
+static bool hasTrait(TypeID traitID) {
+  TypeID traitIDs[] = {TypeID::get<Traits>()...};
+  for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
+    if (traitIDs[i] == traitID)
+      return true;
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Trait Folding
+
+/// Trait to check if T provides a 'foldTrait' method for single result
+/// operations.
+template <typename T, typename... Args>
+using has_single_result_fold_trait = decltype(T::foldTrait(
+    std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>()));
+template <typename T>
+using detect_has_single_result_fold_trait =
+    llvm::is_detected<has_single_result_fold_trait, T>;
+/// Trait to check if T provides a general 'foldTrait' method.
+template <typename T, typename... Args>
+using has_fold_trait =
+    decltype(T::foldTrait(std::declval<Operation *>(),
+                          std::declval<ArrayRef<Attribute>>(),
+                          std::declval<SmallVectorImpl<OpFoldResult> &>()));
+template <typename T>
+using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>;
+/// Trait to check if T provides any `foldTrait` method.
+/// NOTE: This should use std::disjunction when C++17 is available.
+template <typename T>
+using detect_has_any_fold_trait =
+    std::conditional_t<bool(detect_has_fold_trait<T>::value),
+                       detect_has_fold_trait<T>,
+                       detect_has_single_result_fold_trait<T>>;
+
+/// Returns the result of folding a trait that implements a `foldTrait` function
+/// that is specialized for operations that have a single result.
+template <typename Trait>
+static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value,
+                        LogicalResult>
+foldTrait(Operation *op, ArrayRef<Attribute> operands,
+          SmallVectorImpl<OpFoldResult> &results) {
+  assert(op->hasTrait<OpTrait::OneResult>() &&
+         "expected trait on non single-result operation to implement the "
+         "general `foldTrait` method");
+  // If a previous trait has already been folded and replaced this operation, we
+  // fail to fold this trait.
+  if (!results.empty())
+    return failure();
+
+  if (OpFoldResult result = Trait::foldTrait(op, operands)) {
+    if (result.template dyn_cast<Value>() != op->getResult(0))
+      results.push_back(result);
+    return success();
+  }
+  return failure();
+}
+/// Returns the result of folding a trait that implements a generalized
+/// `foldTrait` function that is supports any operation type.
+template <typename Trait>
+static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult>
+foldTrait(Operation *op, ArrayRef<Attribute> operands,
+          SmallVectorImpl<OpFoldResult> &results) {
+  // If a previous trait has already been folded and replaced this operation, we
+  // fail to fold this trait.
+  return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
+}
+
+/// The internal implementation of `foldTraits` below that returns the result of
+/// folding a set of trait types `Ts` that implement a `foldTrait` method.
+template <typename... Ts>
+static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands,
+                                    SmallVectorImpl<OpFoldResult> &results,
+                                    std::tuple<Ts...> *) {
+  bool anyFolded = false;
+  (void)std::initializer_list<int>{
+      (anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
+  return success(anyFolded);
+}
+
+/// Given a tuple type containing a set of traits that contain a `foldTrait`
+/// method, return the result of folding the given operation.
+template <typename TraitTupleT>
+static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult>
+foldTraits(Operation *op, ArrayRef<Attribute> operands,
+           SmallVectorImpl<OpFoldResult> &results) {
+  return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr);
+}
+/// A variant of the method above that is specialized when there are no traits
+/// that contain a `foldTrait` method.
+template <typename TraitTupleT>
+static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult>
+foldTraits(Operation *op, ArrayRef<Attribute> operands,
+           SmallVectorImpl<OpFoldResult> &results) {
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// Trait Properties
+
+/// Trait to check if T provides a `getTraitProperties` method.
+template <typename T, typename... Args>
+using has_get_trait_properties = decltype(T::getTraitProperties());
+template <typename T>
+using detect_has_get_trait_properties =
+    llvm::is_detected<has_get_trait_properties, T>;
+
+/// The internal implementation of `getTraitProperties` below that returns the
+/// OR of invoking `getTraitProperties` on all of the provided trait types `Ts`.
+template <typename... Ts>
+static AbstractOperation::OperationProperties
+getTraitPropertiesImpl(std::tuple<Ts...> *) {
+  AbstractOperation::OperationProperties result = 0;
+  (void)std::initializer_list<int>{(result |= Ts::getTraitProperties(), 0)...};
+  return result;
+}
+
+/// Given a tuple type containing a set of traits that contain a
+/// `getTraitProperties` method, return the OR of all of the results of invoking
+/// those methods.
+template <typename TraitTupleT>
+static AbstractOperation::OperationProperties getTraitProperties() {
+  return getTraitPropertiesImpl((TraitTupleT *)nullptr);
+}
+
+//===----------------------------------------------------------------------===//
+// Trait Verification
+
+/// Trait to check if T provides a `verifyTrait` method.
+template <typename T, typename... Args>
+using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
+template <typename T>
+using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;
+
+/// The internal implementation of `verifyTraits` below that returns the result
+/// of verifying the current operation with all of the provided trait types
+/// `Ts`.
+template <typename... Ts>
+static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) {
+  LogicalResult result = success();
+  (void)std::initializer_list<int>{
+      (result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...};
+  return result;
+}
+
+/// Given a tuple type containing a set of traits that contain a
+/// `verifyTrait` method, return the result of verifying the given operation.
+template <typename TraitTupleT>
+static LogicalResult verifyTraits(Operation *op) {
+  return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
+}
+} // namespace op_definition_impl
+
+//===----------------------------------------------------------------------===//
 // Operation Definition classes
 //===----------------------------------------------------------------------===//
 
@@ -1314,21 +1356,17 @@ struct MemRefsNormalizable
 /// argument 'ConcreteType' should be the concrete type by CRTP and the others
 /// are base classes by the policy pattern.
 template <typename ConcreteType, template <typename T> class... Traits>
-class Op : public OpState,
-           public Traits<ConcreteType>...,
-           public FoldingHook<ConcreteType,
-                              llvm::is_one_of<OpTrait::OneResult<ConcreteType>,
-                                              Traits<ConcreteType>...>::value> {
+class Op : public OpState, public Traits<ConcreteType>... {
 public:
+  /// Inherit getOperation from `OpState`.
+  using OpState::getOperation;
+
   /// Return if this operation contains the provided trait.
   template <template <typename T> class Trait>
   static constexpr bool hasTrait() {
     return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
   }
 
-  /// Return the operation that this refers to.
-  Operation *getOperation() { return OpState::getOperation(); }
-
   /// Create a deep copy of this operation.
   ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }
 
@@ -1339,12 +1377,6 @@ public:
     return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
   }
 
-  /// Return the dialect that this refers to.
-  Dialect *getDialect() { return getOperation()->getDialect(); }
-
-  /// Return the parent Region of this operation.
-  Region *getParentRegion() { return getOperation()->getParentRegion(); }
-
   /// Return true if this "op class" can match against the specified operation.
   static bool classof(Operation *op) {
     if (auto *abstractOp = op->getAbstractOperation())
@@ -1358,56 +1390,6 @@ public:
     return false;
   }
 
-  /// This is the hook used by the AsmParser to parse the custom form of this
-  /// op from an .mlir file.  Op implementations should provide a parse method,
-  /// which returns failure.  On success, they should return fill in result with
-  /// the fields to use.
-  static ParseResult parseAssembly(OpAsmParser &parser,
-                                   OperationState &result) {
-    return ConcreteType::parse(parser, result);
-  }
-
-  /// This is the hook used by the AsmPrinter to emit this to the .mlir file.
-  /// Op implementations should provide a print method.
-  static void printAssembly(Operation *op, OpAsmPrinter &p) {
-    auto opPointer = dyn_cast<ConcreteType>(op);
-    assert(opPointer &&
-           "op's name does not match name of concrete type instantiated with");
-    opPointer.print(p);
-  }
-
-  /// This is the hook that checks whether or not this operation is well
-  /// formed according to the invariants of its opcode.  It delegates to the
-  /// Traits for their policy implementations, and allows the user to specify
-  /// their own verify() method.
-  ///
-  /// On success this returns false; on failure it emits an error to the
-  /// diagnostic subsystem and returns true.
-  static LogicalResult verifyInvariants(Operation *op) {
-    return failure(
-        failed(BaseVerifier<Traits<ConcreteType>...>::verifyTrait(op)) ||
-        failed(cast<ConcreteType>(op).verify()));
-  }
-
-  /// This is the hook that tries to fold the given operation according to its
-  /// traits. It delegates to the Traits for their policy implementations, and
-  /// allows the user to specify their own fold() method.
-  static OpFoldResult foldTraits(Operation *op, ArrayRef<Attribute> operands) {
-    return BaseFolder<Traits<ConcreteType>...>::foldTraits(op, operands);
-  }
-
-  static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
-                                  SmallVectorImpl<OpFoldResult> &results) {
-    return BaseFolder<Traits<ConcreteType>...>::foldTraits(op, operands,
-                                                           results);
-  }
-
-  // Returns the properties of an operation by combining the properties of the
-  // traits of the op.
-  static AbstractOperation::OperationProperties getOperationProperties() {
-    return BaseProperties<Traits<ConcreteType>...>::getTraitProperties();
-  }
-
   /// Expose the type we are instantiated on to template machinery that may want
   /// to introspect traits on this operation.
   using ConcreteOpType = ConcreteType;
@@ -1430,95 +1412,166 @@ public:
   }
 
 private:
-  template <typename... Types> struct BaseVerifier;
-
-  template <typename First, typename... Rest>
-  struct BaseVerifier<First, Rest...> {
-    static LogicalResult verifyTrait(Operation *op) {
-      return failure(failed(First::verifyTrait(op)) ||
-                     failed(BaseVerifier<Rest...>::verifyTrait(op)));
-    }
-  };
-
-  template <typename...> struct BaseVerifier {
-    static LogicalResult verifyTrait(Operation *op) { return success(); }
-  };
-
-  template <typename... Types> struct BaseProperties;
-
-  template <typename First, typename... Rest>
-  struct BaseProperties<First, Rest...> {
-    static AbstractOperation::OperationProperties getTraitProperties() {
-      return First::getTraitProperties() |
-             BaseProperties<Rest...>::getTraitProperties();
-    }
-  };
-
-  template <typename... Types>
-  struct BaseFolder;
-
-  template <typename First, typename... Rest>
-  struct BaseFolder<First, Rest...> {
-    static OpFoldResult foldTraits(Operation *op,
-                                   ArrayRef<Attribute> operands) {
-      auto result = First::foldTrait(op, operands);
-      // Failure to fold or in place fold both mean we can continue folding.
-      if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
-        // Only consider the trait fold result if it is a success since
-        // the operation fold might have been a success originally.
-        auto resultRemaining = BaseFolder<Rest...>::foldTraits(op, operands);
-        if (resultRemaining)
-          result = resultRemaining;
-      }
-
-      return result;
-    }
-
-    static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
-                                    SmallVectorImpl<OpFoldResult> &results) {
-      auto result = First::foldTrait(op, operands, results);
-      // Failure to fold or in place fold both mean we can continue folding.
-      if (failed(result) || results.empty()) {
-        auto resultRemaining =
-            BaseFolder<Rest...>::foldTraits(op, operands, results);
-        if (succeeded(resultRemaining))
-          result = resultRemaining;
-      }
+  /// Trait to check if T provides a 'fold' method for a single result op.
+  template <typename T, typename... Args>
+  using has_single_result_fold =
+      decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
+  template <typename T>
+  using detect_has_single_result_fold =
+      llvm::is_detected<has_single_result_fold, T>;
+  /// Trait to check if T provides a general 'fold' method.
+  template <typename T, typename... Args>
+  using has_fold = decltype(
+      std::declval<T>().fold(std::declval<ArrayRef<Attribute>>(),
+                             std::declval<SmallVectorImpl<OpFoldResult> &>()));
+  template <typename T> using detect_has_fold = llvm::is_detected<has_fold, T>;
+  /// Trait to check if T provides a 'print' method.
+  template <typename T, typename... Args>
+  using has_print =
+      decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
+  template <typename T>
+  using detect_has_print = llvm::is_detected<has_print, T>;
+  /// A tuple type containing the traits that have a `foldTrait` function.
+  using FoldableTraitsTupleT = typename detail::FilterTypes<
+      op_definition_impl::detect_has_any_fold_trait,
+      Traits<ConcreteType>...>::type;
+  /// A tuple type containing the traits that have a verify function.
+  using VerifiableTraitsTupleT =
+      typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
+                                   Traits<ConcreteType>...>::type;
+
+  /// Returns the properties of this operation by combining the properties
+  /// defined by the traits.
+  static AbstractOperation::OperationProperties getOperationProperties() {
+    return op_definition_impl::getTraitProperties<typename detail::FilterTypes<
+        op_definition_impl::detect_has_get_trait_properties,
+        Traits<ConcreteType>...>::type>();
+  }
 
-      return result;
-    }
-  };
+  /// Returns an interface map containing the interfaces registered to this
+  /// operation.
+  static detail::InterfaceMap getInterfaceMap() {
+    return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
+  }
 
-  template <typename...>
-  struct BaseFolder {
-    static OpFoldResult foldTraits(Operation *op,
-                                   ArrayRef<Attribute> operands) {
-      return {};
-    }
-    static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
-                                    SmallVectorImpl<OpFoldResult> &results) {
-      return failure();
+  /// Return the internal implementations of each of the AbstractOperation
+  /// hooks.
+  /// Implementation of `FoldHookFn` AbstractOperation hook.
+  static AbstractOperation::FoldHookFn getFoldHookFn() {
+    return getFoldHookFnImpl<ConcreteType>();
+  }
+  /// The internal implementation of `getFoldHookFn` above that is invoked if
+  /// the operation is single result and defines a `fold` method.
+  template <typename ConcreteOpT>
+  static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
+                                          Traits<ConcreteOpT>...>::value &&
+                              detect_has_single_result_fold<ConcreteOpT>::value,
+                          AbstractOperation::FoldHookFn>
+  getFoldHookFnImpl() {
+    return &foldSingleResultHook<ConcreteOpT>;
+  }
+  /// The internal implementation of `getFoldHookFn` above that is invoked if
+  /// the operation is not single result and defines a `fold` method.
+  template <typename ConcreteOpT>
+  static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
+                                           Traits<ConcreteOpT>...>::value &&
+                              detect_has_fold<ConcreteOpT>::value,
+                          AbstractOperation::FoldHookFn>
+  getFoldHookFnImpl() {
+    return &foldHook<ConcreteOpT>;
+  }
+  /// The internal implementation of `getFoldHookFn` above that is invoked if
+  /// the operation does not define a `fold` method.
+  template <typename ConcreteOpT>
+  static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
+                              !detect_has_fold<ConcreteOpT>::value,
+                          AbstractOperation::FoldHookFn>
+  getFoldHookFnImpl() {
+    // In this case, we only need to fold the traits of the operation.
+    return &op_definition_impl::foldTraits<FoldableTraitsTupleT>;
+  }
+  /// Return the result of folding a single result operation that defines a
+  /// `fold` method.
+  template <typename ConcreteOpT>
+  static LogicalResult
+  foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
+                       SmallVectorImpl<OpFoldResult> &results) {
+    OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);
+
+    // If the fold failed or was in-place, try to fold the traits of the
+    // operation.
+    if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
+      if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
+              op, operands, results)))
+        return success();
+      return success(static_cast<bool>(result));
     }
-  };
+    results.push_back(result);
+    return success();
+  }
+  /// Return the result of folding an operation that defines a `fold` method.
+  template <typename ConcreteOpT>
+  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
+                                SmallVectorImpl<OpFoldResult> &results) {
+    LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);
 
-  template <typename...> struct BaseProperties {
-    static AbstractOperation::OperationProperties getTraitProperties() {
-      return 0;
+    // If the fold failed or was in-place, try to fold the traits of the
+    // operation.
+    if (failed(result) || results.empty()) {
+      if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
+              op, operands, results)))
+        return success();
     }
-  };
-
-  /// Returns true if this operation contains the trait for the given typeID.
-  static bool hasTrait(TypeID traitID) {
-    return llvm::is_contained(llvm::makeArrayRef({TypeID::get<Traits>()...}),
-                              traitID);
+    return result;
+  }
+
+  /// Implementation of `GetCanonicalizationPatternsFn` AbstractOperation hook.
+  static AbstractOperation::GetCanonicalizationPatternsFn
+  getGetCanonicalizationPatternsFn() {
+    return &ConcreteType::getCanonicalizationPatterns;
+  }
+  /// Implementation of `GetHasTraitFn`
+  static AbstractOperation::HasTraitFn getHasTraitFn() {
+    return &op_definition_impl::hasTrait<Traits...>;
+  }
+  /// Implementation of `ParseAssemblyFn` AbstractOperation hook.
+  static AbstractOperation::ParseAssemblyFn getParseAssemblyFn() {
+    return &ConcreteType::parse;
+  }
+  /// Implementation of `PrintAssemblyFn` AbstractOperation hook.
+  static AbstractOperation::PrintAssemblyFn getPrintAssemblyFn() {
+    return getPrintAssemblyFnImpl<ConcreteType>();
+  }
+  /// The internal implementation of `getPrintAssemblyFn` that is invoked when
+  /// the concrete operation does not define a `print` method.
+  template <typename ConcreteOpT>
+  static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
+                          AbstractOperation::PrintAssemblyFn>
+  getPrintAssemblyFnImpl() {
+    return &OpState::print;
+  }
+  /// The internal implementation of `getPrintAssemblyFn` that is invoked when
+  /// the concrete operation defines a `print` method.
+  template <typename ConcreteOpT>
+  static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
+                          AbstractOperation::PrintAssemblyFn>
+  getPrintAssemblyFnImpl() {
+    return &printAssembly;
   }
-
-  /// Returns an interface map for the interfaces registered to this operation.
-  static detail::InterfaceMap getInterfaceMap() {
-    return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
+  static void printAssembly(Operation *op, OpAsmPrinter &p) {
+    return cast<ConcreteType>(op).print(p);
+  }
+  /// Implementation of `VerifyInvariantsFn` AbstractOperation hook.
+  static AbstractOperation::VerifyInvariantsFn getVerifyInvariantsFn() {
+    return &verifyInvariants;
+  }
+  static LogicalResult verifyInvariants(Operation *op) {
+    return failure(
+        failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
+        failed(cast<ConcreteType>(op).verify()));
   }
 
-  /// Allow access to 'hasTrait' and 'getInterfaceMap'.
+  /// Allow access to internal implementation methods.
   friend AbstractOperation;
 };
 
index f2ee83d..563a906 100644 (file)
@@ -80,6 +80,15 @@ class AbstractOperation {
 public:
   using OperationProperties = uint32_t;
 
+  using GetCanonicalizationPatternsFn = void (*)(OwningRewritePatternList &,
+                                                 MLIRContext *);
+  using FoldHookFn = LogicalResult (*)(Operation *, ArrayRef<Attribute>,
+                                       SmallVectorImpl<OpFoldResult> &);
+  using HasTraitFn = bool (*)(TypeID);
+  using ParseAssemblyFn = ParseResult (*)(OpAsmParser &, OperationState &);
+  using PrintAssemblyFn = void (*)(Operation *, OpAsmPrinter &);
+  using VerifyInvariantsFn = LogicalResult (*)(Operation *);
+
   /// This is the name of the operation.
   const Identifier name;
 
@@ -90,15 +99,19 @@ public:
   TypeID typeID;
 
   /// Use the specified object to parse this ops custom assembly format.
-  ParseResult (&parseAssembly)(OpAsmParser &parser, OperationState &result);
+  ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
 
   /// This hook implements the AsmPrinter for this operation.
-  void (&printAssembly)(Operation *op, OpAsmPrinter &p);
+  void printAssembly(Operation *op, OpAsmPrinter &p) const {
+    return printAssemblyFn(op, p);
+  }
 
   /// This hook implements the verifier for this operation.  It should emits an
   /// error message and returns failure if a problem is detected, or returns
   /// success if everything is ok.
-  LogicalResult (&verifyInvariants)(Operation *op);
+  LogicalResult verifyInvariants(Operation *op) const {
+    return verifyInvariantsFn(op);
+  }
 
   /// This hook implements a generalized folder for this operation.  Operations
   /// can implement this to provide simplifications rules that are applied by
@@ -119,13 +132,17 @@ public:
   /// This allows expression of some simple in-place canonicalizations (e.g.
   /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
   /// generalized constant folding.
-  LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
-                            SmallVectorImpl<OpFoldResult> &results);
+  LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
+                         SmallVectorImpl<OpFoldResult> &results) const {
+    return foldHookFn(op, operands, results);
+  }
 
   /// This hook returns any canonicalization pattern rewrites that the operation
   /// supports, for use by the canonicalization pass.
-  void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
-                                      MLIRContext *context);
+  void getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                   MLIRContext *context) const {
+    return getCanonicalizationPatternsFn(results, context);
+  }
 
   /// Returns whether the operation has a particular property.
   bool hasProperty(OperationProperty property) const {
@@ -141,7 +158,7 @@ public:
 
   /// Returns true if the operation has a particular trait.
   template <template <typename T> class Trait> bool hasTrait() const {
-    return hasRawTrait(TypeID::get<Trait>());
+    return hasTraitFn(TypeID::get<Trait>());
   }
 
   /// Look up the specified operation in the specified MLIRContext and return a
@@ -151,26 +168,30 @@ public:
 
   /// This constructor is used by Dialect objects when they register the list of
   /// operations they contain.
-  template <typename T> static AbstractOperation get(Dialect &dialect) {
-    return AbstractOperation(
-        T::getOperationName(), dialect, T::getOperationProperties(),
-        TypeID::get<T>(), T::parseAssembly, T::printAssembly,
-        T::verifyInvariants, T::foldHook, T::getCanonicalizationPatterns,
-        T::getInterfaceMap(), T::hasTrait);
+  template <typename T> static void insert(Dialect &dialect) {
+    insert(T::getOperationName(), dialect, T::getOperationProperties(),
+           TypeID::get<T>(), T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
+           T::getVerifyInvariantsFn(), T::getFoldHookFn(),
+           T::getGetCanonicalizationPatternsFn(), T::getInterfaceMap(),
+           T::getHasTraitFn());
   }
 
 private:
-  AbstractOperation(
-      StringRef name, Dialect &dialect, OperationProperties opProperties,
-      TypeID typeID,
-      ParseResult (&parseAssembly)(OpAsmParser &parser, OperationState &result),
-      void (&printAssembly)(Operation *op, OpAsmPrinter &p),
-      LogicalResult (&verifyInvariants)(Operation *op),
-      LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
-                                SmallVectorImpl<OpFoldResult> &results),
-      void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
-                                          MLIRContext *context),
-      detail::InterfaceMap &&interfaceMap, bool (&hasTrait)(TypeID traitID));
+  static void insert(StringRef name, Dialect &dialect,
+                     OperationProperties opProperties, TypeID typeID,
+                     ParseAssemblyFn parseAssembly,
+                     PrintAssemblyFn printAssembly,
+                     VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
+                     GetCanonicalizationPatternsFn getCanonicalizationPatterns,
+                     detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait);
+
+  AbstractOperation(StringRef name, Dialect &dialect,
+                    OperationProperties opProperties, TypeID typeID,
+                    ParseAssemblyFn parseAssembly,
+                    PrintAssemblyFn printAssembly,
+                    VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
+                    GetCanonicalizationPatternsFn getCanonicalizationPatterns,
+                    detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait);
 
   /// The properties of the operation.
   const OperationProperties opProperties;
@@ -178,9 +199,13 @@ private:
   /// A map of interfaces that were registered to this operation.
   detail::InterfaceMap interfaceMap;
 
-  /// This hook returns if the operation contains the trait corresponding
-  /// to the given TypeID.
-  bool (&hasRawTrait)(TypeID traitID);
+  /// Internal callback hooks provided by the op implementation.
+  FoldHookFn foldHookFn;
+  GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
+  HasTraitFn hasTraitFn;
+  ParseAssemblyFn parseAssemblyFn;
+  PrintAssemblyFn printAssemblyFn;
+  VerifyInvariantsFn verifyInvariantsFn;
 };
 
 //===----------------------------------------------------------------------===//
index cea9d47..284cf2a 100644 (file)
@@ -621,22 +621,6 @@ bool MLIRContext::isOperationRegistered(StringRef name) {
   return impl->registeredOperations.count(name);
 }
 
-void Dialect::addOperation(AbstractOperation opInfo) {
-  assert((getNamespace().empty() || opInfo.dialect.name == getNamespace()) &&
-         "op name doesn't start with dialect namespace");
-  assert(&opInfo.dialect == this && "Dialect object mismatch");
-  auto &impl = context->getImpl();
-  assert(impl.multiThreadedExecutionContext == 0 &&
-         "Registering a new operation kind while in a multi-threaded execution "
-         "context");
-  StringRef opName = opInfo.name;
-  if (!impl.registeredOperations.insert({opName, std::move(opInfo)}).second) {
-    llvm::errs() << "error: operation named '" << opInfo.name
-                 << "' is already registered.\n";
-    abort();
-  }
-}
-
 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
   auto &impl = context->getImpl();
   assert(impl.multiThreadedExecutionContext == 0 &&
@@ -661,6 +645,10 @@ void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
     llvm::report_fatal_error("Dialect Attribute already registered.");
 }
 
+//===----------------------------------------------------------------------===//
+// AbstractAttribute
+//===----------------------------------------------------------------------===//
+
 /// Get the dialect that registered the attribute with the provided typeid.
 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
                                                    MLIRContext *context) {
@@ -672,8 +660,17 @@ const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
   return *it->second;
 }
 
+//===----------------------------------------------------------------------===//
+// AbstractOperation
+//===----------------------------------------------------------------------===//
+
+ParseResult AbstractOperation::parseAssembly(OpAsmParser &parser,
+                                             OperationState &result) const {
+  return parseAssemblyFn(parser, result);
+}
+
 /// Look up the specified operation in the operation set and return a pointer
-/// to it if present.  Otherwise, return a null pointer.
+/// to it if present. Otherwise, return a null pointer.
 const AbstractOperation *AbstractOperation::lookup(StringRef opName,
                                                    MLIRContext *context) {
   auto &impl = context->getImpl();
@@ -683,26 +680,45 @@ const AbstractOperation *AbstractOperation::lookup(StringRef opName,
   return nullptr;
 }
 
+void AbstractOperation::insert(
+    StringRef name, Dialect &dialect, OperationProperties opProperties,
+    TypeID typeID, ParseAssemblyFn parseAssembly, PrintAssemblyFn printAssembly,
+    VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
+    GetCanonicalizationPatternsFn getCanonicalizationPatterns,
+    detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait) {
+  AbstractOperation opInfo(name, dialect, opProperties, typeID, parseAssembly,
+                           printAssembly, verifyInvariants, foldHook,
+                           getCanonicalizationPatterns, std::move(interfaceMap),
+                           hasTrait);
+
+  auto &impl = dialect.getContext()->getImpl();
+  assert(impl.multiThreadedExecutionContext == 0 &&
+         "Registering a new operation kind while in a multi-threaded execution "
+         "context");
+  if (!impl.registeredOperations.insert({name, std::move(opInfo)}).second) {
+    llvm::errs() << "error: operation named '" << name
+                 << "' is already registered.\n";
+    abort();
+  }
+}
+
 AbstractOperation::AbstractOperation(
     StringRef name, Dialect &dialect, OperationProperties opProperties,
-    TypeID typeID,
-    ParseResult (&parseAssembly)(OpAsmParser &parser, OperationState &result),
-    void (&printAssembly)(Operation *op, OpAsmPrinter &p),
-    LogicalResult (&verifyInvariants)(Operation *op),
-    LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
-                              SmallVectorImpl<OpFoldResult> &results),
-    void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
-                                        MLIRContext *context),
-    detail::InterfaceMap &&interfaceMap, bool (&hasTrait)(TypeID traitID))
+    TypeID typeID, ParseAssemblyFn parseAssembly, PrintAssemblyFn printAssembly,
+    VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
+    GetCanonicalizationPatternsFn getCanonicalizationPatterns,
+    detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait)
     : name(Identifier::get(name, dialect.getContext())), dialect(dialect),
-      typeID(typeID), parseAssembly(parseAssembly),
-      printAssembly(printAssembly), verifyInvariants(verifyInvariants),
-      foldHook(foldHook),
-      getCanonicalizationPatterns(getCanonicalizationPatterns),
-      opProperties(opProperties), interfaceMap(std::move(interfaceMap)),
-      hasRawTrait(hasTrait) {}
-
-/// Get the dialect that registered the type with the provided typeid.
+      typeID(typeID), opProperties(opProperties),
+      interfaceMap(std::move(interfaceMap)), foldHookFn(foldHook),
+      getCanonicalizationPatternsFn(getCanonicalizationPatterns),
+      hasTraitFn(hasTrait), parseAssemblyFn(parseAssembly),
+      printAssemblyFn(printAssembly), verifyInvariantsFn(verifyInvariants) {}
+
+//===----------------------------------------------------------------------===//
+// AbstractType
+//===----------------------------------------------------------------------===//
+
 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
   auto &impl = context->getImpl();
   auto it = impl.registeredTypes.find(typeID);
index fe86c6f..c85efe8 100644 (file)
@@ -649,7 +649,7 @@ ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 // The fallback for the printer is to print in the generic assembly form.
-void OpState::print(OpAsmPrinter &p) { p.printGenericOp(getOperation()); }
+void OpState::print(Operation *op, OpAsmPrinter &p) { p.printGenericOp(op); }
 
 /// Emit an error about fatal conditions with this operation, reporting up to
 /// any diagnostic handlers that may be listening.