#include "llvm/ADT/ArrayRef.h"
namespace mlir {
+class FunctionPassBase;
class ModulePassBase;
namespace linalg {
-ModulePassBase *createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {});
+FunctionPassBase *createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {});
ModulePassBase *createLowerLinalgToLLVMPass();
-
} // namespace linalg
} // namespace mlir
}
namespace {
-struct LinalgTilingPass : public ModulePass<LinalgTilingPass> {
+struct LinalgTilingPass : public FunctionPass<LinalgTilingPass> {
LinalgTilingPass();
LinalgTilingPass(ArrayRef<int64_t> sizes);
- void runOnModule() {
- for (auto &f : getModule())
- tileLinalgOps(f, tileSizes);
- }
+ void runOnFunction() { tileLinalgOps(getFunction(), tileSizes); }
SmallVector<int64_t, 8> tileSizes;
};
this->tileSizes.assign(sizes.begin(), sizes.end());
}
-ModulePassBase *
+FunctionPassBase *
mlir::linalg::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
return new LinalgTilingPass(tileSizes);
}