using namespace mlir::linalg;
using namespace mlir::linalg::intrinsics;
-using IndexedLinalgValue = TemplatedIndexedValue<std_load, std_store>;
+using IndexedStdValue = TemplatedIndexedValue<std_load, std_store>;
using edsc::op::operator+;
using edsc::op::operator==;
static SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
AffineMap map,
ArrayRef<Value *> allViewSizes,
- OperationFolder *folder) {
+ OperationFolder *folder);
+SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
+ AffineMap map,
+ ArrayRef<Value *> allViewSizes,
+ OperationFolder *folder) {
// Apply `map` to get view sizes in loop order.
auto sizes = applyMapToValues(b, loc, map, allViewSizes, folder);
// Create a new range with the applied tile sizes.
return res;
}
-template <typename LinalgOpType> class LinalgScopedEmitter {};
+template <typename IndexedValueType, typename LinalgOpType>
+class LinalgScopedEmitter {};
-template <> class LinalgScopedEmitter<CopyOp> {
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, CopyOp> {
public:
static void emitScalarImplementation(ArrayRef<Value *> allIvs, CopyOp copyOp,
OperationFolder *folder) {
permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation(), folder);
SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
- IndexedLinalgValue O(copyOp.getOutput(0)), I(copyOp.getInput(0));
+ IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0));
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
// an n-D loop nest; with or without permutations.
// clang-format off
}
};
-template <> class LinalgScopedEmitter<FillOp> {
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, FillOp> {
public:
static void emitScalarImplementation(ArrayRef<Value *> allIvs, FillOp fillOp,
OperationFolder *folder) {
assert(nPar == allIvs.size());
auto ivs =
SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar);
- IndexedLinalgValue O(fillOp.getOutput(0));
+ IndexedValueType O(fillOp.getOutput(0));
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
// an n-D loop nest; with or without permutations.
nPar > 0 ? O(ivs) = ValueHandle(fillOp.getValue())
}
};
-template <> class LinalgScopedEmitter<DotOp> {
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, DotOp> {
public:
static void emitScalarImplementation(ArrayRef<Value *> allIvs, DotOp dotOp,
OperationFolder *folder) {
assert(allIvs.size() == 1);
IndexHandle r_i(allIvs[0]);
- IndexedLinalgValue A(dotOp.getInput(0)), B(dotOp.getInput(1)),
+ IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
C(dotOp.getOutput(0));
// Emit scalar form.
C() = C() + A(r_i) * B(r_i);
}
};
-template <> class LinalgScopedEmitter<MatvecOp> {
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, MatvecOp> {
public:
static void emitScalarImplementation(ArrayRef<Value *> allIvs,
MatvecOp matvecOp,
OperationFolder *folder) {
assert(allIvs.size() == 2);
IndexHandle i(allIvs[0]), r_j(allIvs[1]);
- IndexedLinalgValue A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
+ IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
C(matvecOp.getOutput(0));
// Emit scalar form.
C(i) = C(i) + A(i, r_j) * B(r_j);
}
};
-template <> class LinalgScopedEmitter<MatmulOp> {
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, MatmulOp> {
public:
static void emitScalarImplementation(ArrayRef<Value *> allIvs,
MatmulOp matmulOp,
OperationFolder *folder) {
assert(allIvs.size() == 3);
IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
- IndexedLinalgValue A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
+ IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
C(matmulOp.getOutput(0));
// Emit scalar form.
C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
}
};
-template <> class LinalgScopedEmitter<ConvOp> {
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, ConvOp> {
public:
static void emitScalarImplementation(ArrayRef<Value *> allIvs, ConvOp convOp,
OperationFolder *folder) {
foldedAffineApplies(b, loc, maps[1], allIvs, folder));
SmallVector<ValueHandle, 8> oIdx(
foldedAffineApplies(b, loc, maps[2], allIvs, folder));
- IndexedLinalgValue F(convOp.filter()), I(convOp.input()),
- O(convOp.output());
+ IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output());
// Emit scalar form.
O(oIdx) += F(fIdx) * I(imIdx);
}
// loop.for %i = %c0 to %0 step %c1 {
// loop.for %j = %c0 to %1 step %c1 {
// loop.for %k = %c0 to %4 step %c1 {
-// %11 = linalg.load %arg0[%i, %j] :
+// %11 = load %arg0[%i, %j] :
// memref<?x?xf32, stride_specification>
-// %12 = linalg.load %arg1[%i, %j, %k] :
+// %12 = load %arg1[%i, %j, %k] :
// memref<?x?x?xf32, stride_specification>
-// %13 = linalg.load %arg2[%i, %k, %j] :
+// %13 = load %arg2[%i, %k, %j] :
// memref<?x?x?xf32, stride_specification>
// %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
-// linalg.store %14#0, %arg1[%i, %j, %k] :
+// store %14#0, %arg1[%i, %j, %k] :
// memref<?x?x?Xf32, stride_specification>
-// linalg.store %14#1, %arg2[%i, %k, %j] :
+// store %14#1, %arg2[%i, %k, %j] :
// memref<?x?x?Xf32, stride_specification>
// }
// }
// }
// ```
-template <> class LinalgScopedEmitter<GenericOp> {
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, GenericOp> {
public:
static void emitScalarImplementation(ArrayRef<Value *> allIvs,
GenericOp genericOp,
}
};
-template <> class LinalgScopedEmitter<IndexedGenericOp> {
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
public:
static void emitScalarImplementation(ArrayRef<Value *> allIvs,
IndexedGenericOp genericOp,
}
};
-template <typename ConcreteOp>
+template <typename LoopType, typename IndexedValueType, typename ConcreteOp>
class LinalgRewritePattern : public RewritePattern {
public:
explicit LinalgRewritePattern(MLIRContext *context)
auto invertedMap =
inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
if (!invertedMap) {
- LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation({}, linalgOp,
+ LinalgScopedEmitter<IndexedValueType,
+ ConcreteOp>::emitScalarImplementation({}, linalgOp,
&folder);
rewriter.eraseOp(op);
return matchSuccess();
auto nWin = linalgOp.getNumWindowLoops();
SmallVector<IndexHandle, 4> allIvs(nPar + nRed + nWin);
SmallVector<ValueHandle *, 4> allPIvs = makeIndexHandlePointers(allIvs);
- auto pivs = MutableArrayRef<ValueHandle *>(allPIvs).take_front(nPar);
- auto rivs = MutableArrayRef<ValueHandle *>(allPIvs)
- .take_front(nPar + nRed)
- .take_back(nRed);
- auto wivs = MutableArrayRef<ValueHandle *>(allPIvs).take_back(nWin);
-
auto loopRanges =
emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
getViewSizes(linalgOp), &folder);
- assert(loopRanges.size() == pivs.size() + rivs.size() + wivs.size());
-
- // clang-format off
- ArrayRef<Value *> ranges(loopRanges);
- LoopNestRangeBuilder(pivs, ranges.take_front(nPar))([&] {
- LoopNestRangeBuilder(rivs, ranges.drop_back(nWin).take_back(nRed))([&] {
- LoopNestRangeBuilder(wivs, ranges.take_back(wivs.size()))(
- [&linalgOp, &allIvs, this] {
- auto allIvValues = extractValues(allIvs);
- LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation(
- allIvValues, linalgOp, &folder);
- });
- });
+ assert(loopRanges.size() == allIvs.size());
+
+ // clang-format off;
+ LoopNestRangeBuilder(allPIvs, loopRanges)([&] {
+ auto allIvValues = extractValues(allIvs);
+ LinalgScopedEmitter<IndexedValueType,
+ ConcreteOp>::emitScalarImplementation(allIvValues,
+ linalgOp,
+ &folder);
});
// clang-format on
rewriter.eraseOp(op);
};
// Helper classes for type list expansion.
-template <typename... LinalgOps> class ConversionList;
+template <typename LoopType, typename IndexedValueType, typename... LinalgOps>
+class ConversionList;
-template <> class ConversionList<> {
+template <typename LoopType, typename IndexedValueType>
+class ConversionList<LoopType, IndexedValueType> {
public:
static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
};
-template <typename ConcreteOp, typename... LinalgOps>
-class ConversionList<ConcreteOp, LinalgOps...> {
+template <typename LoopType, typename IndexedValueType, typename ConcreteOp,
+ typename... LinalgOps>
+class ConversionList<LoopType, IndexedValueType, ConcreteOp, LinalgOps...> {
public:
static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
- patterns.insert<LinalgRewritePattern<ConcreteOp>>(ctx);
- ConversionList<LinalgOps...>::build(patterns, ctx);
+ patterns
+ .insert<LinalgRewritePattern<LoopType, IndexedValueType, ConcreteOp>>(
+ ctx);
+ ConversionList<LoopType, IndexedValueType, LinalgOps...>::build(patterns,
+ ctx);
}
};
/// Populate the given list with patterns that convert from Linalg to LLVM.
-static void
-populateLinalgToLoopRewritePatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
- ConversionList<
+template <typename LoopType, typename IndexedValueType>
+void ForOpRewritePatterns(OwningRewritePatternList &patterns,
+ MLIRContext *ctx) {
+ ConversionList<LoopType, IndexedValueType,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.cpp.inc"
- >::build(patterns, ctx);
+ >::build(patterns, ctx);
}
namespace {
-struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> {
+template <typename LoopType, typename IndexedValueType>
+struct LowerLinalgToLoopsPass
+ : public FunctionPass<LowerLinalgToLoopsPass<LoopType, IndexedValueType>> {
void runOnFunction() override;
};
} // namespace
-void LowerLinalgToLoopsPass::runOnFunction() {
+template <typename LoopType, typename IndexedValueType>
+void LowerLinalgToLoopsPass<LoopType, IndexedValueType>::runOnFunction() {
OwningRewritePatternList patterns;
- populateLinalgToLoopRewritePatterns(patterns, &getContext());
+ ForOpRewritePatterns<LoopType, IndexedValueType>(patterns,
+ &this->getContext());
- ConversionTarget target(getContext());
+ ConversionTarget target(this->getContext());
target.addLegalDialect<AffineOpsDialect>();
target.addLegalDialect<loop::LoopOpsDialect>();
target.addLegalDialect<StandardOpsDialect>();
- if (failed(applyPartialConversion(getFunction(), target, patterns))) {
- signalPassFailure();
+ if (failed(applyPartialConversion(this->getFunction(), target, patterns))) {
+ this->signalPassFailure();
}
}
std::unique_ptr<OpPassBase<FuncOp>>
mlir::linalg::createLowerLinalgToLoopsPass() {
- return std::make_unique<LowerLinalgToLoopsPass>();
+ return std::make_unique<
+ LowerLinalgToLoopsPass<loop::ForOp, IndexedStdValue>>();
}
-static PassRegistration<LowerLinalgToLoopsPass>
- pass("linalg-lower-to-loops",
- "Lower the operations from the linalg dialect into loops");
+static PassRegistration<LowerLinalgToLoopsPass<loop::ForOp, IndexedStdValue>>
+ structuredLoopsPass(
+ "linalg-lower-to-loops",
+ "Lower the operations from the linalg dialect into loops");