/// Patterns to apply `splitReduction` below.
void populateSplitReductionPattern(
RewritePatternSet &patterns,
- ControlSplitReductionFn controlSplitReductionFn,
- LinalgTransformationFilter f = LinalgTransformationFilter());
+ const ControlSplitReductionFn &controlSplitReductionFn,
+ const LinalgTransformationFilter &f = LinalgTransformationFilter());
/// Apply transformation to split the single linalg op reduction into a parallel
/// and reduction dimension. Then create a new linalg.generic op doing the rest
/// ```
FailureOr<LinalgOp>
splitReduction(PatternRewriter &b, LinalgOp op,
- ControlSplitReductionFn controlSplitReductionFn,
- LinalgTransformationFilter f);
+ const ControlSplitReductionFn &controlSplitReductionFn,
+ const LinalgTransformationFilter &f);
} // namespace linalg
} // namespace mlir
//
//===----------------------------------------------------------------------===//
+#include <utility>
+
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
return llvm::None;
}
-FailureOr<LinalgOp>
-mlir::linalg::splitReduction(PatternRewriter &b, LinalgOp op,
- ControlSplitReductionFn controlSplitReductionFn,
- LinalgTransformationFilter filter) {
+FailureOr<LinalgOp> mlir::linalg::splitReduction(
+ PatternRewriter &b, LinalgOp op,
+ const ControlSplitReductionFn &controlSplitReductionFn,
+ const LinalgTransformationFilter &filter) {
if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
!op.hasOnlyProjectedPermutations())
ControlSplitReductionFn controlSplitReductionFn,
LinalgTransformationFilter f, PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
- controlSplitReductionFn(controlSplitReductionFn), filter(std::move(f)) {
- }
+ controlSplitReductionFn(std::move(controlSplitReductionFn)),
+ filter(std::move(f)) {}
LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
void linalg::populateSplitReductionPattern(
RewritePatternSet &patterns,
- ControlSplitReductionFn controlSplitReductionFn,
- LinalgTransformationFilter f) {
+ const ControlSplitReductionFn &controlSplitReductionFn,
+ const LinalgTransformationFilter &f) {
patterns.add<LinalgSplitReduction>(patterns.getContext(),
controlSplitReductionFn, f);
}