void sfinae_enqueue(OwningRewritePatternList &patterList, OptionsType options,
MLIRContext *context, StringRef opName,
linalg::LinalgTransformationFilter m) {
- assert(opName.empty() ||
- opName == ConcreteOpType::getOperationName() &&
- "explicit name must match ConcreteOpType::getOperationName");
+ assert(opName == ConcreteOpType::getOperationName() &&
+ "explicit name must match ConcreteOpType::getOperationName");
patterList.insert<PatternType<ConcreteOpType>>(context, options, m);
}
struct Tile : public Transformation {
explicit Tile(linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
- : Transformation(f), opName(""), options(options) {}
+ : Transformation(f), opName(LinalgOpType::getOperationName()),
+ options(options) {}
Tile(StringRef name, linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
explicit Promote(
linalg::LinalgPromotionOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
- : Transformation(f), opName(""), options(options) {}
+ : Transformation(f), opName(LinalgOpType::getOperationName()),
+ options(options) {}
Promote(StringRef name, linalg::LinalgPromotionOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
explicit Vectorize(
linalg::LinalgVectorizationOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
- : Transformation(f), opName(""), options(options) {}
+ : Transformation(f), opName(LinalgOpType::getOperationName()),
+ options(options) {}
Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
void runOnFunction() override;
+ template <typename OpType>
+ void runStrategy(LinalgTilingOptions tilingOptions,
+ LinalgTilingOptions registerTilingOptions,
+ vector::VectorContractLowering vectorContractLowering,
+ vector::VectorTransferSplit vectorTransferSplit);
+
ListOption<int64_t> tileSizes{*this, "tile-sizes",
llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::desc("Specifies the tile sizes.")};
};
} // end anonymous namespace
+template <>
+void TestLinalgCodegenStrategy::runStrategy<LinalgOp>(
+ LinalgTilingOptions tilingOptions,
+ LinalgTilingOptions registerTilingOptions,
+ vector::VectorContractLowering vectorContractLowering,
+ vector::VectorTransferSplit vectorTransferSplit) {
+ assert(!anchorOpName.empty());
+ CodegenStrategy strategy;
+ strategy.tileIf<LinalgOp>(!tileSizes.empty(), anchorOpName, tilingOptions)
+ .promoteIf<LinalgOp>(promote, anchorOpName,
+ LinalgPromotionOptions()
+ .setAlignment(16)
+ .setUseFullTileBuffersByDefault(promoteFullTile))
+ .tileIf<LinalgOp>(!registerTileSizes.empty(), anchorOpName,
+ registerTilingOptions)
+ .promoteIf<LinalgOp>(
+ registerPromote, anchorOpName,
+ LinalgPromotionOptions()
+ .setAlignment(16)
+ .setUseFullTileBuffersByDefault(registerPromoteFullTile))
+ .vectorizeIf<LinalgOp>(vectorize, anchorOpName)
+ .setVectorTransformsOptions(
+ vector::VectorTransformsOptions()
+ .setVectorTransformsOptions(vectorContractLowering)
+ .setVectorTransferSplit(vectorTransferSplit))
+ .setVectorTransferToSCFOptions(
+ VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
+ strategy.transform(getFunction());
+}
+
+template <typename OpType>
+void TestLinalgCodegenStrategy::runStrategy(
+ LinalgTilingOptions tilingOptions,
+ LinalgTilingOptions registerTilingOptions,
+ vector::VectorContractLowering vectorContractLowering,
+ vector::VectorTransferSplit vectorTransferSplit) {
+ CodegenStrategy strategy;
+ strategy.tileIf<OpType>(!tileSizes.empty(), tilingOptions)
+ .template promoteIf<OpType>(
+ promote, LinalgPromotionOptions()
+ .setAlignment(16)
+ .setUseFullTileBuffersByDefault(promoteFullTile))
+ .template tileIf<OpType>(!registerTileSizes.empty(),
+ registerTilingOptions)
+ .template promoteIf<OpType>(
+ registerPromote,
+ LinalgPromotionOptions()
+ .setAlignment(16)
+ .setUseFullTileBuffersByDefault(registerPromoteFullTile))
+ .template vectorizeIf<OpType>(vectorize)
+ .setVectorTransformsOptions(
+ vector::VectorTransformsOptions()
+ .setVectorTransformsOptions(vectorContractLowering)
+ .setVectorTransferSplit(vectorTransferSplit))
+ .setVectorTransferToSCFOptions(
+ VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
+ strategy.transform(getFunction());
+}
+
/// Apply transformations specified as patterns.
void TestLinalgCodegenStrategy::runOnFunction() {
- linalg::LinalgTransformationFilter::FilterFunction filterOpName =
- [&](Operation *op) -> LogicalResult {
- return success(op->getName().getStringRef() == anchorOpName);
- };
LinalgTilingOptions tilingOptions;
if (!tileSizes.empty())
tilingOptions = tilingOptions.setTileSizes(tileSizes);
.Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer)
.Default(vector::VectorTransferSplit::None);
- CodegenStrategy strategy;
- strategy.tileIf<LinalgOp>(!tileSizes.empty(), anchorOpName, tilingOptions)
- .promoteIf<LinalgOp>(promote, anchorOpName,
- LinalgPromotionOptions()
- .setAlignment(16)
- .setUseFullTileBuffersByDefault(promoteFullTile),
- filterOpName)
- .tileIf<LinalgOp>(!registerTileSizes.empty(), anchorOpName,
- registerTilingOptions)
- .promoteIf<LinalgOp>(
- registerPromote, anchorOpName,
- LinalgPromotionOptions()
- .setAlignment(16)
- .setUseFullTileBuffersByDefault(registerPromoteFullTile))
- .vectorizeIf<LinalgOp>(vectorize, anchorOpName)
- .setVectorTransformsOptions(
- vector::VectorTransformsOptions()
- .setVectorTransformsOptions(vectorContractLowering)
- .setVectorTransferSplit(vectorTransferSplit))
- .setVectorTransferToSCFOptions(
- VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
- strategy.transform(getFunction());
+ // If no anchorOpNameis specified, just test that strategy applies properly to
+ // linalg::MatmulOp.
+ if (anchorOpName.empty())
+ runStrategy<linalg::MatmulOp>(tilingOptions, registerTilingOptions,
+ vectorContractLowering, vectorTransferSplit);
+ else
+ runStrategy<LinalgOp>(tilingOptions, registerTilingOptions,
+ vectorContractLowering, vectorTransferSplit);
}
namespace mlir {