/// Trait to check if T provides a 'fold' method for a single result op.
template <typename T, typename... Args>
using has_single_result_fold_t =
- decltype(std::declval<T>().fold(std::declval<typename T::FoldAdaptor>()));
+ decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
template <typename T>
constexpr static bool has_single_result_fold_v =
llvm::is_detected<has_single_result_fold_t, T>::value;
/// Trait to check if T provides a general 'fold' method.
template <typename T, typename... Args>
using has_fold_t = decltype(std::declval<T>().fold(
- std::declval<typename T::FoldAdaptor>(),
+ std::declval<ArrayRef<Attribute>>(),
std::declval<SmallVectorImpl<OpFoldResult> &>()));
template <typename T>
constexpr static bool has_fold_v = llvm::is_detected<has_fold_t, T>::value;
+ /// Trait to check if T provides a 'fold' method with a FoldAdaptor for a
+ /// single result op.
+ template <typename T, typename... Args>
+ using has_fold_adaptor_single_result_fold_t =
+ decltype(std::declval<T>().fold(std::declval<typename T::FoldAdaptor>()));
+ template <class T>
+ constexpr static bool has_fold_adaptor_single_result_v =
+ llvm::is_detected<has_fold_adaptor_single_result_fold_t, T>::value;
+ /// Trait to check if T provides a general 'fold' method with a FoldAdaptor.
+ template <typename T, typename... Args>
+ using has_fold_adaptor_fold_t = decltype(std::declval<T>().fold(
+ std::declval<typename T::FoldAdaptor>(),
+ std::declval<SmallVectorImpl<OpFoldResult> &>()));
+ template <class T>
+ constexpr static bool has_fold_adaptor_v =
+ llvm::is_detected<has_fold_adaptor_fold_t, T>::value;
/// Trait to check if T provides a 'print' method.
template <typename T, typename... Args>
// If the operation is single result and defines a `fold` method.
if constexpr (llvm::is_one_of<OpTrait::OneResult<ConcreteType>,
Traits<ConcreteType>...>::value &&
- has_single_result_fold_v<ConcreteType>)
+ (has_single_result_fold_v<ConcreteType> ||
+ has_fold_adaptor_single_result_v<ConcreteType>))
return [](Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return foldSingleResultHook<ConcreteType>(op, operands, results);
};
// The operation is not single result and defines a `fold` method.
- if constexpr (has_fold_v<ConcreteType>)
+ if constexpr (has_fold_v<ConcreteType> || has_fold_adaptor_v<ConcreteType>)
return [](Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return foldHook<ConcreteType>(op, operands, results);
static LogicalResult
foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
- OpFoldResult result =
- cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
- operands, op->getAttrDictionary(), op->getRegions()));
+ OpFoldResult result;
+ if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>)
+ result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
+ operands, op->getAttrDictionary(), op->getRegions()));
+ else
+ result = cast<ConcreteOpT>(op).fold(operands);
// If the fold failed or was in-place, try to fold the traits of the
// operation.
template <typename ConcreteOpT>
static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
- LogicalResult result = cast<ConcreteOpT>(op).fold(
- typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
- op->getRegions()),
- results);
+ auto result = LogicalResult::failure();
+ if constexpr (has_fold_adaptor_v<ConcreteOpT>) {
+ result = cast<ConcreteOpT>(op).fold(
+ typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
+ op->getRegions()),
+ results);
+ } else {
+ result = cast<ConcreteOpT>(op).fold(operands, results);
+ }
// If the fold failed or was in-place, try to fold the traits of the
// operation.