/// Return the operation that this refers to.
Operation *getOperation() { return state; }
+ /// Return the dialect that this refers to.
+ Dialect *getDialect() { return getOperation()->getDialect(); }
+
+ /// Return the parent Region of this operation.
+ Region *getParentRegion() { return getOperation()->getParentRegion(); }
+
/// Returns the closest surrounding operation that contains this operation
/// or nullptr if this is a top-level operation.
Operation *getParentOp() { return getOperation()->getParentOp(); }
static ParseResult parse(OpAsmParser &parser, OperationState &result);
// The fallback for the printer is to print it the generic assembly form.
- void print(OpAsmPrinter &p);
+ static void print(Operation *op, OpAsmPrinter &p);
/// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
/// so we can cast it away here.
private:
Operation *state;
+
+ /// Allow access to internal hook implementation methods.
+ friend AbstractOperation;
};
// Allow comparing operators.
return os;
}
-/// This template defines the foldHook as used by AbstractOperation.
-///
-/// The default implementation uses a general fold method that can be defined on
-/// custom ops which can return multiple results.
-template <typename ConcreteType, bool isSingleResult, typename = void>
-class FoldingHook {
-public:
- /// This is an implementation detail of the constant folder hook for
- /// AbstractOperation.
- static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results) {
- auto operationFoldResult = cast<ConcreteType>(op).fold(operands, results);
- // Failure to fold or in place fold both mean we can continue folding.
- if (failed(operationFoldResult) || results.empty()) {
- auto traitFoldResult = ConcreteType::foldTraits(op, operands, results);
- // Only return the trait fold result if it is a success since
- // operationFoldResult might have been a success originally.
- if (succeeded(traitFoldResult))
- return traitFoldResult;
- }
- return operationFoldResult;
- }
-
- /// This hook implements a generalized folder for this operation. Operations
- /// can implement this to provide simplifications rules that are applied by
- /// the Builder::createOrFold API and the canonicalization pass.
- ///
- /// This is an intentionally limited interface - implementations of this hook
- /// can only perform the following changes to the operation:
- ///
- /// 1. They can leave the operation alone and without changing the IR, and
- /// return failure.
- /// 2. They can mutate the operation in place, without changing anything else
- /// in the IR. In this case, return success.
- /// 3. They can return a list of existing values that can be used instead of
- /// the operation. In this case, fill in the results list and return
- /// success. The caller will remove the operation and use those results
- /// instead.
- ///
- /// This allows expression of some simple in-place canonicalizations (e.g.
- /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
- /// generalized constant folding.
- ///
- /// If not overridden, this fallback implementation always fails to fold.
- ///
- LogicalResult fold(ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results) {
- return failure();
- }
-};
-
-/// This template specialization defines the foldHook as used by
-/// AbstractOperation for single-result operations. This gives the hook a nicer
-/// signature that is easier to implement.
-template <typename ConcreteType, bool isSingleResult>
-class FoldingHook<ConcreteType, isSingleResult,
- typename std::enable_if<isSingleResult>::type> {
-public:
- /// If the operation returns a single value, then the Op can be implicitly
- /// converted to an Value. This yields the value of the only result.
- operator Value() {
- return static_cast<ConcreteType *>(this)->getOperation()->getResult(0);
- }
-
- /// This is an implementation detail of the constant folder hook for
- /// AbstractOperation.
- static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results) {
- auto result = cast<ConcreteType>(op).fold(operands);
- // Failure to fold or in place fold both mean we can continue folding.
- if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
- // Only consider the trait fold result if it is a success since
- // the operation fold might have been a success originally.
- if (auto traitFoldResult = ConcreteType::foldTraits(op, operands))
- result = traitFoldResult;
- }
-
- if (!result)
- return failure();
-
- // Check if the operation was folded in place. In this case, the operation
- // returns itself.
- if (result.template dyn_cast<Value>() != op->getResult(0))
- results.push_back(result);
- return success();
- }
-
- /// This hook implements a generalized folder for this operation. Operations
- /// can implement this to provide simplifications rules that are applied by
- /// the Builder::createOrFold API and the canonicalization pass.
- ///
- /// This is an intentionally limited interface - implementations of this hook
- /// can only perform the following changes to the operation:
- ///
- /// 1. They can leave the operation alone and without changing the IR, and
- /// return nullptr.
- /// 2. They can mutate the operation in place, without changing anything else
- /// in the IR. In this case, return the operation itself.
- /// 3. They can return an existing SSA value that can be used instead of
- /// the operation. In this case, return that value. The caller will
- /// remove the operation and use that result instead.
- ///
- /// This allows expression of some simple in-place canonicalizations (e.g.
- /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
- /// generalized constant folding.
- ///
- /// If not overridden, this fallback implementation always fails to fold.
- ///
- OpFoldResult fold(ArrayRef<Attribute> operands) { return {}; }
-};
-
//===----------------------------------------------------------------------===//
// Operation Trait Types
//===----------------------------------------------------------------------===//
auto *base = static_cast<OpState *>(concrete);
return base->getOperation();
}
-
- /// Provide default implementations of trait hooks. This allows traits to
- /// provide exactly the overrides they care about.
- static LogicalResult verifyTrait(Operation *op) { return success(); }
- static AbstractOperation::OperationProperties getTraitProperties() {
- return 0;
- }
-
- static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
- SmallVector<OpFoldResult, 1> results;
- if (failed(foldTrait(op, operands, results)))
- return {};
- if (results.empty())
- return op->getResult(0);
- assert(results.size() == 1 &&
- "Single result op cannot return multiple fold results");
-
- return results[0];
- }
-
- static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results) {
- return failure();
- }
};
//===----------------------------------------------------------------------===//
Value getResult() { return this->getOperation()->getResult(0); }
Type getType() { return getResult().getType(); }
+ /// If the operation returns a single value, then the Op can be implicitly
+ /// converted to an Value. This yields the value of the only result.
+ operator Value() { return getResult(); }
+
/// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns
/// there are zero uses of 'this'.
} // end namespace OpTrait
//===----------------------------------------------------------------------===//
+// Internal Trait Utilities
+//===----------------------------------------------------------------------===//
+
+namespace op_definition_impl {
+//===----------------------------------------------------------------------===//
+// Trait Existence
+
+/// Returns true if this given Trait ID matches the IDs of any of the provided
+/// trait types `Traits`.
+template <template <typename T> class... Traits>
+static bool hasTrait(TypeID traitID) {
+ TypeID traitIDs[] = {TypeID::get<Traits>()...};
+ for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
+ if (traitIDs[i] == traitID)
+ return true;
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Trait Folding
+
+/// Trait to check if T provides a 'foldTrait' method for single result
+/// operations.
+template <typename T, typename... Args>
+using has_single_result_fold_trait = decltype(T::foldTrait(
+ std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>()));
+template <typename T>
+using detect_has_single_result_fold_trait =
+ llvm::is_detected<has_single_result_fold_trait, T>;
+/// Trait to check if T provides a general 'foldTrait' method.
+template <typename T, typename... Args>
+using has_fold_trait =
+ decltype(T::foldTrait(std::declval<Operation *>(),
+ std::declval<ArrayRef<Attribute>>(),
+ std::declval<SmallVectorImpl<OpFoldResult> &>()));
+template <typename T>
+using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>;
+/// Trait to check if T provides any `foldTrait` method.
+/// NOTE: This should use std::disjunction when C++17 is available.
+template <typename T>
+using detect_has_any_fold_trait =
+ std::conditional_t<bool(detect_has_fold_trait<T>::value),
+ detect_has_fold_trait<T>,
+ detect_has_single_result_fold_trait<T>>;
+
+/// Returns the result of folding a trait that implements a `foldTrait` function
+/// that is specialized for operations that have a single result.
+template <typename Trait>
+static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value,
+ LogicalResult>
+foldTrait(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ assert(op->hasTrait<OpTrait::OneResult>() &&
+ "expected trait on non single-result operation to implement the "
+ "general `foldTrait` method");
+ // If a previous trait has already been folded and replaced this operation, we
+ // fail to fold this trait.
+ if (!results.empty())
+ return failure();
+
+ if (OpFoldResult result = Trait::foldTrait(op, operands)) {
+ if (result.template dyn_cast<Value>() != op->getResult(0))
+ results.push_back(result);
+ return success();
+ }
+ return failure();
+}
+/// Returns the result of folding a trait that implements a generalized
+/// `foldTrait` function that is supports any operation type.
+template <typename Trait>
+static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult>
+foldTrait(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ // If a previous trait has already been folded and replaced this operation, we
+ // fail to fold this trait.
+ return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
+}
+
+/// The internal implementation of `foldTraits` below that returns the result of
+/// folding a set of trait types `Ts` that implement a `foldTrait` method.
+template <typename... Ts>
+static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results,
+ std::tuple<Ts...> *) {
+ bool anyFolded = false;
+ (void)std::initializer_list<int>{
+ (anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
+ return success(anyFolded);
+}
+
+/// Given a tuple type containing a set of traits that contain a `foldTrait`
+/// method, return the result of folding the given operation.
+template <typename TraitTupleT>
+static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult>
+foldTraits(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr);
+}
+/// A variant of the method above that is specialized when there are no traits
+/// that contain a `foldTrait` method.
+template <typename TraitTupleT>
+static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult>
+foldTraits(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// Trait Properties
+
+/// Trait to check if T provides a `getTraitProperties` method.
+template <typename T, typename... Args>
+using has_get_trait_properties = decltype(T::getTraitProperties());
+template <typename T>
+using detect_has_get_trait_properties =
+ llvm::is_detected<has_get_trait_properties, T>;
+
+/// The internal implementation of `getTraitProperties` below that returns the
+/// OR of invoking `getTraitProperties` on all of the provided trait types `Ts`.
+template <typename... Ts>
+static AbstractOperation::OperationProperties
+getTraitPropertiesImpl(std::tuple<Ts...> *) {
+ AbstractOperation::OperationProperties result = 0;
+ (void)std::initializer_list<int>{(result |= Ts::getTraitProperties(), 0)...};
+ return result;
+}
+
+/// Given a tuple type containing a set of traits that contain a
+/// `getTraitProperties` method, return the OR of all of the results of invoking
+/// those methods.
+template <typename TraitTupleT>
+static AbstractOperation::OperationProperties getTraitProperties() {
+ return getTraitPropertiesImpl((TraitTupleT *)nullptr);
+}
+
+//===----------------------------------------------------------------------===//
+// Trait Verification
+
+/// Trait to check if T provides a `verifyTrait` method.
+template <typename T, typename... Args>
+using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
+template <typename T>
+using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;
+
+/// The internal implementation of `verifyTraits` below that returns the result
+/// of verifying the current operation with all of the provided trait types
+/// `Ts`.
+template <typename... Ts>
+static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) {
+ LogicalResult result = success();
+ (void)std::initializer_list<int>{
+ (result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...};
+ return result;
+}
+
+/// Given a tuple type containing a set of traits that contain a
+/// `verifyTrait` method, return the result of verifying the given operation.
+template <typename TraitTupleT>
+static LogicalResult verifyTraits(Operation *op) {
+ return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
+}
+} // namespace op_definition_impl
+
+//===----------------------------------------------------------------------===//
// Operation Definition classes
//===----------------------------------------------------------------------===//
/// argument 'ConcreteType' should be the concrete type by CRTP and the others
/// are base classes by the policy pattern.
template <typename ConcreteType, template <typename T> class... Traits>
-class Op : public OpState,
- public Traits<ConcreteType>...,
- public FoldingHook<ConcreteType,
- llvm::is_one_of<OpTrait::OneResult<ConcreteType>,
- Traits<ConcreteType>...>::value> {
+class Op : public OpState, public Traits<ConcreteType>... {
public:
+ /// Inherit getOperation from `OpState`.
+ using OpState::getOperation;
+
/// Return if this operation contains the provided trait.
template <template <typename T> class Trait>
static constexpr bool hasTrait() {
return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
}
- /// Return the operation that this refers to.
- Operation *getOperation() { return OpState::getOperation(); }
-
/// Create a deep copy of this operation.
ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }
return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
}
- /// Return the dialect that this refers to.
- Dialect *getDialect() { return getOperation()->getDialect(); }
-
- /// Return the parent Region of this operation.
- Region *getParentRegion() { return getOperation()->getParentRegion(); }
-
/// Return true if this "op class" can match against the specified operation.
static bool classof(Operation *op) {
if (auto *abstractOp = op->getAbstractOperation())
return false;
}
- /// This is the hook used by the AsmParser to parse the custom form of this
- /// op from an .mlir file. Op implementations should provide a parse method,
- /// which returns failure. On success, they should return fill in result with
- /// the fields to use.
- static ParseResult parseAssembly(OpAsmParser &parser,
- OperationState &result) {
- return ConcreteType::parse(parser, result);
- }
-
- /// This is the hook used by the AsmPrinter to emit this to the .mlir file.
- /// Op implementations should provide a print method.
- static void printAssembly(Operation *op, OpAsmPrinter &p) {
- auto opPointer = dyn_cast<ConcreteType>(op);
- assert(opPointer &&
- "op's name does not match name of concrete type instantiated with");
- opPointer.print(p);
- }
-
- /// This is the hook that checks whether or not this operation is well
- /// formed according to the invariants of its opcode. It delegates to the
- /// Traits for their policy implementations, and allows the user to specify
- /// their own verify() method.
- ///
- /// On success this returns false; on failure it emits an error to the
- /// diagnostic subsystem and returns true.
- static LogicalResult verifyInvariants(Operation *op) {
- return failure(
- failed(BaseVerifier<Traits<ConcreteType>...>::verifyTrait(op)) ||
- failed(cast<ConcreteType>(op).verify()));
- }
-
- /// This is the hook that tries to fold the given operation according to its
- /// traits. It delegates to the Traits for their policy implementations, and
- /// allows the user to specify their own fold() method.
- static OpFoldResult foldTraits(Operation *op, ArrayRef<Attribute> operands) {
- return BaseFolder<Traits<ConcreteType>...>::foldTraits(op, operands);
- }
-
- static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results) {
- return BaseFolder<Traits<ConcreteType>...>::foldTraits(op, operands,
- results);
- }
-
- // Returns the properties of an operation by combining the properties of the
- // traits of the op.
- static AbstractOperation::OperationProperties getOperationProperties() {
- return BaseProperties<Traits<ConcreteType>...>::getTraitProperties();
- }
-
/// Expose the type we are instantiated on to template machinery that may want
/// to introspect traits on this operation.
using ConcreteOpType = ConcreteType;
}
private:
- template <typename... Types> struct BaseVerifier;
-
- template <typename First, typename... Rest>
- struct BaseVerifier<First, Rest...> {
- static LogicalResult verifyTrait(Operation *op) {
- return failure(failed(First::verifyTrait(op)) ||
- failed(BaseVerifier<Rest...>::verifyTrait(op)));
- }
- };
-
- template <typename...> struct BaseVerifier {
- static LogicalResult verifyTrait(Operation *op) { return success(); }
- };
-
- template <typename... Types> struct BaseProperties;
-
- template <typename First, typename... Rest>
- struct BaseProperties<First, Rest...> {
- static AbstractOperation::OperationProperties getTraitProperties() {
- return First::getTraitProperties() |
- BaseProperties<Rest...>::getTraitProperties();
- }
- };
-
- template <typename... Types>
- struct BaseFolder;
-
- template <typename First, typename... Rest>
- struct BaseFolder<First, Rest...> {
- static OpFoldResult foldTraits(Operation *op,
- ArrayRef<Attribute> operands) {
- auto result = First::foldTrait(op, operands);
- // Failure to fold or in place fold both mean we can continue folding.
- if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
- // Only consider the trait fold result if it is a success since
- // the operation fold might have been a success originally.
- auto resultRemaining = BaseFolder<Rest...>::foldTraits(op, operands);
- if (resultRemaining)
- result = resultRemaining;
- }
-
- return result;
- }
-
- static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results) {
- auto result = First::foldTrait(op, operands, results);
- // Failure to fold or in place fold both mean we can continue folding.
- if (failed(result) || results.empty()) {
- auto resultRemaining =
- BaseFolder<Rest...>::foldTraits(op, operands, results);
- if (succeeded(resultRemaining))
- result = resultRemaining;
- }
+ /// Trait to check if T provides a 'fold' method for a single result op.
+ template <typename T, typename... Args>
+ using has_single_result_fold =
+ decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
+ template <typename T>
+ using detect_has_single_result_fold =
+ llvm::is_detected<has_single_result_fold, T>;
+ /// Trait to check if T provides a general 'fold' method.
+ template <typename T, typename... Args>
+ using has_fold = decltype(
+ std::declval<T>().fold(std::declval<ArrayRef<Attribute>>(),
+ std::declval<SmallVectorImpl<OpFoldResult> &>()));
+ template <typename T> using detect_has_fold = llvm::is_detected<has_fold, T>;
+ /// Trait to check if T provides a 'print' method.
+ template <typename T, typename... Args>
+ using has_print =
+ decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
+ template <typename T>
+ using detect_has_print = llvm::is_detected<has_print, T>;
+ /// A tuple type containing the traits that have a `foldTrait` function.
+ using FoldableTraitsTupleT = typename detail::FilterTypes<
+ op_definition_impl::detect_has_any_fold_trait,
+ Traits<ConcreteType>...>::type;
+ /// A tuple type containing the traits that have a verify function.
+ using VerifiableTraitsTupleT =
+ typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
+ Traits<ConcreteType>...>::type;
+
+ /// Returns the properties of this operation by combining the properties
+ /// defined by the traits.
+ static AbstractOperation::OperationProperties getOperationProperties() {
+ return op_definition_impl::getTraitProperties<typename detail::FilterTypes<
+ op_definition_impl::detect_has_get_trait_properties,
+ Traits<ConcreteType>...>::type>();
+ }
- return result;
- }
- };
+ /// Returns an interface map containing the interfaces registered to this
+ /// operation.
+ static detail::InterfaceMap getInterfaceMap() {
+ return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
+ }
- template <typename...>
- struct BaseFolder {
- static OpFoldResult foldTraits(Operation *op,
- ArrayRef<Attribute> operands) {
- return {};
- }
- static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results) {
- return failure();
+ /// Return the internal implementations of each of the AbstractOperation
+ /// hooks.
+ /// Implementation of `FoldHookFn` AbstractOperation hook.
+ static AbstractOperation::FoldHookFn getFoldHookFn() {
+ return getFoldHookFnImpl<ConcreteType>();
+ }
+ /// The internal implementation of `getFoldHookFn` above that is invoked if
+ /// the operation is single result and defines a `fold` method.
+ template <typename ConcreteOpT>
+ static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
+ Traits<ConcreteOpT>...>::value &&
+ detect_has_single_result_fold<ConcreteOpT>::value,
+ AbstractOperation::FoldHookFn>
+ getFoldHookFnImpl() {
+ return &foldSingleResultHook<ConcreteOpT>;
+ }
+ /// The internal implementation of `getFoldHookFn` above that is invoked if
+ /// the operation is not single result and defines a `fold` method.
+ template <typename ConcreteOpT>
+ static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
+ Traits<ConcreteOpT>...>::value &&
+ detect_has_fold<ConcreteOpT>::value,
+ AbstractOperation::FoldHookFn>
+ getFoldHookFnImpl() {
+ return &foldHook<ConcreteOpT>;
+ }
+ /// The internal implementation of `getFoldHookFn` above that is invoked if
+ /// the operation does not define a `fold` method.
+ template <typename ConcreteOpT>
+ static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
+ !detect_has_fold<ConcreteOpT>::value,
+ AbstractOperation::FoldHookFn>
+ getFoldHookFnImpl() {
+ // In this case, we only need to fold the traits of the operation.
+ return &op_definition_impl::foldTraits<FoldableTraitsTupleT>;
+ }
+ /// Return the result of folding a single result operation that defines a
+ /// `fold` method.
+ template <typename ConcreteOpT>
+ static LogicalResult
+ foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);
+
+ // If the fold failed or was in-place, try to fold the traits of the
+ // operation.
+ if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
+ if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
+ op, operands, results)))
+ return success();
+ return success(static_cast<bool>(result));
}
- };
+ results.push_back(result);
+ return success();
+ }
+ /// Return the result of folding an operation that defines a `fold` method.
+ template <typename ConcreteOpT>
+ static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);
- template <typename...> struct BaseProperties {
- static AbstractOperation::OperationProperties getTraitProperties() {
- return 0;
+ // If the fold failed or was in-place, try to fold the traits of the
+ // operation.
+ if (failed(result) || results.empty()) {
+ if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
+ op, operands, results)))
+ return success();
}
- };
-
- /// Returns true if this operation contains the trait for the given typeID.
- static bool hasTrait(TypeID traitID) {
- return llvm::is_contained(llvm::makeArrayRef({TypeID::get<Traits>()...}),
- traitID);
+ return result;
+ }
+
+ /// Implementation of `GetCanonicalizationPatternsFn` AbstractOperation hook.
+ static AbstractOperation::GetCanonicalizationPatternsFn
+ getGetCanonicalizationPatternsFn() {
+ return &ConcreteType::getCanonicalizationPatterns;
+ }
+ /// Implementation of `GetHasTraitFn`
+ static AbstractOperation::HasTraitFn getHasTraitFn() {
+ return &op_definition_impl::hasTrait<Traits...>;
+ }
+ /// Implementation of `ParseAssemblyFn` AbstractOperation hook.
+ static AbstractOperation::ParseAssemblyFn getParseAssemblyFn() {
+ return &ConcreteType::parse;
+ }
+ /// Implementation of `PrintAssemblyFn` AbstractOperation hook.
+ static AbstractOperation::PrintAssemblyFn getPrintAssemblyFn() {
+ return getPrintAssemblyFnImpl<ConcreteType>();
+ }
+ /// The internal implementation of `getPrintAssemblyFn` that is invoked when
+ /// the concrete operation does not define a `print` method.
+ template <typename ConcreteOpT>
+ static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
+ AbstractOperation::PrintAssemblyFn>
+ getPrintAssemblyFnImpl() {
+ return &OpState::print;
+ }
+ /// The internal implementation of `getPrintAssemblyFn` that is invoked when
+ /// the concrete operation defines a `print` method.
+ template <typename ConcreteOpT>
+ static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
+ AbstractOperation::PrintAssemblyFn>
+ getPrintAssemblyFnImpl() {
+ return &printAssembly;
}
-
- /// Returns an interface map for the interfaces registered to this operation.
- static detail::InterfaceMap getInterfaceMap() {
- return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
+ static void printAssembly(Operation *op, OpAsmPrinter &p) {
+ return cast<ConcreteType>(op).print(p);
+ }
+ /// Implementation of `VerifyInvariantsFn` AbstractOperation hook.
+ static AbstractOperation::VerifyInvariantsFn getVerifyInvariantsFn() {
+ return &verifyInvariants;
+ }
+ static LogicalResult verifyInvariants(Operation *op) {
+ return failure(
+ failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
+ failed(cast<ConcreteType>(op).verify()));
}
- /// Allow access to 'hasTrait' and 'getInterfaceMap'.
+ /// Allow access to internal implementation methods.
friend AbstractOperation;
};
public:
using OperationProperties = uint32_t;
+ using GetCanonicalizationPatternsFn = void (*)(OwningRewritePatternList &,
+ MLIRContext *);
+ using FoldHookFn = LogicalResult (*)(Operation *, ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &);
+ using HasTraitFn = bool (*)(TypeID);
+ using ParseAssemblyFn = ParseResult (*)(OpAsmParser &, OperationState &);
+ using PrintAssemblyFn = void (*)(Operation *, OpAsmPrinter &);
+ using VerifyInvariantsFn = LogicalResult (*)(Operation *);
+
/// This is the name of the operation.
const Identifier name;
TypeID typeID;
/// Use the specified object to parse this ops custom assembly format.
- ParseResult (&parseAssembly)(OpAsmParser &parser, OperationState &result);
+ ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
/// This hook implements the AsmPrinter for this operation.
- void (&printAssembly)(Operation *op, OpAsmPrinter &p);
+ void printAssembly(Operation *op, OpAsmPrinter &p) const {
+ return printAssemblyFn(op, p);
+ }
/// This hook implements the verifier for this operation. It should emits an
/// error message and returns failure if a problem is detected, or returns
/// success if everything is ok.
- LogicalResult (&verifyInvariants)(Operation *op);
+ LogicalResult verifyInvariants(Operation *op) const {
+ return verifyInvariantsFn(op);
+ }
/// This hook implements a generalized folder for this operation. Operations
/// can implement this to provide simplifications rules that are applied by
/// This allows expression of some simple in-place canonicalizations (e.g.
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
/// generalized constant folding.
- LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results);
+ LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) const {
+ return foldHookFn(op, operands, results);
+ }
/// This hook returns any canonicalization pattern rewrites that the operation
/// supports, for use by the canonicalization pass.
- void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
- MLIRContext *context);
+ void getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) const {
+ return getCanonicalizationPatternsFn(results, context);
+ }
/// Returns whether the operation has a particular property.
bool hasProperty(OperationProperty property) const {
/// Returns true if the operation has a particular trait.
template <template <typename T> class Trait> bool hasTrait() const {
- return hasRawTrait(TypeID::get<Trait>());
+ return hasTraitFn(TypeID::get<Trait>());
}
/// Look up the specified operation in the specified MLIRContext and return a
/// This constructor is used by Dialect objects when they register the list of
/// operations they contain.
- template <typename T> static AbstractOperation get(Dialect &dialect) {
- return AbstractOperation(
- T::getOperationName(), dialect, T::getOperationProperties(),
- TypeID::get<T>(), T::parseAssembly, T::printAssembly,
- T::verifyInvariants, T::foldHook, T::getCanonicalizationPatterns,
- T::getInterfaceMap(), T::hasTrait);
+ template <typename T> static void insert(Dialect &dialect) {
+ insert(T::getOperationName(), dialect, T::getOperationProperties(),
+ TypeID::get<T>(), T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
+ T::getVerifyInvariantsFn(), T::getFoldHookFn(),
+ T::getGetCanonicalizationPatternsFn(), T::getInterfaceMap(),
+ T::getHasTraitFn());
}
private:
- AbstractOperation(
- StringRef name, Dialect &dialect, OperationProperties opProperties,
- TypeID typeID,
- ParseResult (&parseAssembly)(OpAsmParser &parser, OperationState &result),
- void (&printAssembly)(Operation *op, OpAsmPrinter &p),
- LogicalResult (&verifyInvariants)(Operation *op),
- LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results),
- void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
- MLIRContext *context),
- detail::InterfaceMap &&interfaceMap, bool (&hasTrait)(TypeID traitID));
+ static void insert(StringRef name, Dialect &dialect,
+ OperationProperties opProperties, TypeID typeID,
+ ParseAssemblyFn parseAssembly,
+ PrintAssemblyFn printAssembly,
+ VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
+ GetCanonicalizationPatternsFn getCanonicalizationPatterns,
+ detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait);
+
+ AbstractOperation(StringRef name, Dialect &dialect,
+ OperationProperties opProperties, TypeID typeID,
+ ParseAssemblyFn parseAssembly,
+ PrintAssemblyFn printAssembly,
+ VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
+ GetCanonicalizationPatternsFn getCanonicalizationPatterns,
+ detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait);
/// The properties of the operation.
const OperationProperties opProperties;
/// A map of interfaces that were registered to this operation.
detail::InterfaceMap interfaceMap;
- /// This hook returns if the operation contains the trait corresponding
- /// to the given TypeID.
- bool (&hasRawTrait)(TypeID traitID);
+ /// Internal callback hooks provided by the op implementation.
+ FoldHookFn foldHookFn;
+ GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
+ HasTraitFn hasTraitFn;
+ ParseAssemblyFn parseAssemblyFn;
+ PrintAssemblyFn printAssemblyFn;
+ VerifyInvariantsFn verifyInvariantsFn;
};
//===----------------------------------------------------------------------===//