/// Represent one application of createLinalgStrategyVectorizePass.
struct Vectorize : public Transformation {
explicit Vectorize(linalg::LinalgVectorizationOptions options,
- LinalgTransformationFilter::FilterFunction f = nullptr)
- : Transformation(f), opName(), options(options) {}
+ LinalgTransformationFilter::FilterFunction f = nullptr,
+ bool padVectorize = false)
+ : Transformation(f), opName(), options(options),
+ vectorizePadding(padVectorize) {}
Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
- LinalgTransformationFilter::FilterFunction f = nullptr)
- : Transformation(f), opName(name), options(options) {}
+ LinalgTransformationFilter::FilterFunction f = nullptr,
+ bool padVectorize = false)
+ : Transformation(f), opName(name), options(options),
+ vectorizePadding(padVectorize) {}
void addToPassPipeline(OpPassManager &pm,
LinalgTransformationFilter m) const override {
- pm.addPass(createLinalgStrategyVectorizePass(opName, options, m));
+ pm.addPass(createLinalgStrategyVectorizePass(opName, options, m,
+ vectorizePadding));
}
private:
std::string opName;
linalg::LinalgVectorizationOptions options;
+ bool vectorizePadding;
};
/// Represent one application of createLinalgStrategyLowerVectorsPass.
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
CodegenStrategy &
vectorize(StringRef opName,
- LinalgTransformationFilter::FilterFunction f = nullptr) {
+ LinalgTransformationFilter::FilterFunction f = nullptr,
+ bool vectorizePadding = false) {
assert(!opName.empty() && "expected an op name");
transformationSequence.emplace_back(std::make_unique<Vectorize>(
- opName, linalg::LinalgVectorizationOptions(), f));
+ opName, linalg::LinalgVectorizationOptions(), f, vectorizePadding));
return *this;
}
/// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
/// operation.
CodegenStrategy &
vectorizeIf(bool b, StringRef opName,
- LinalgTransformationFilter::FilterFunction f = nullptr) {
- return b ? vectorize(opName, f) : *this;
+ LinalgTransformationFilter::FilterFunction f = nullptr,
+ bool vectorizePadding = false) {
+ return b ? vectorize(opName, f, vectorizePadding) : *this;
return *this;
}
/// Append a pattern to lower all vector operations.
LinalgStrategyVectorizePass() = default;
LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
- LinalgTransformationFilter filt)
+ LinalgTransformationFilter filt,
+ bool padVectorize = false)
: options(opt), filter(filt) {
this->anchorOpName.setValue(opName.str());
+ this->vectorizePadding.setValue(padVectorize);
}
void runOnFunction() override {
vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
linalg::LinalgCopyVTWForwardingPattern>(
funcOp.getContext(), /*benefit=*/2);
+ if (vectorizePadding) {
+ linalg::populatePadTensorOpVectorizationPatterns(vectorizationPatterns);
+ }
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(vectorizationPatterns));
}
}
/// Create a LinalgStrategyVectorizePass.
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createLinalgStrategyVectorizePass(StringRef opName,
- LinalgVectorizationOptions opt,
- LinalgTransformationFilter filter) {
- return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter);
+std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgStrategyVectorizePass(
+ StringRef opName, LinalgVectorizationOptions opt,
+ LinalgTransformationFilter filter, bool padVectorize) {
+ return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter,
+ padVectorize);
}
/// Create a LinalgStrategyEnablePass.