From 24647750d487c2c496bd996d15ddaa7af090ef73 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <ntv@google.com>
Date: Tue, 6 Aug 2019 05:37:47 -0700
Subject: [PATCH] Refactor Linalg ops to loop lowering (NFC)

This CL modifies the LowerLinalgToLoopsPass to use RewritePattern.
This will make it easier to inline Linalg generic functions and regions
when emitting to loops in a subsequent CL.

PiperOrigin-RevId: 261894120
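In outline, the refactored pass registers one rewrite pattern per Linalg op
and hands control to the conversion framework; the full code is in
LowerToLoops.cpp below:

```
// The new pass driver in a nutshell: build the pattern list, declare the
// affine/loop/std dialects legal, and let partial conversion rewrite every
// Linalg op in the function.
void LowerLinalgToLoopsPass::runOnFunction() {
  OwningRewritePatternList patterns;
  populateLinalgToLoopRewritePatterns(patterns, &getContext());

  ConversionTarget target(getContext());
  target.addLegalDialect<AffineOpsDialect>();
  target.addLegalDialect<loop::LoopOpsDialect>();
  target.addLegalDialect<StandardOpsDialect>();
  if (failed(applyPartialConversion(getFunction(), target, std::move(patterns))))
    signalPassFailure();
}
```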
---
 .../Linalg/Linalg3/lib/Transforms.cpp         |   4 +-
 .../Linalg/Linalg4/lib/Transforms.cpp         |   2 +-
 .../Tutorials/Linalg/DeclarativeBuilders.md   |   4 +-
 mlir/include/mlir/EDSC/Intrinsics.h           |  34 +-
 mlir/include/mlir/Linalg/IR/LinalgOps.h       |   5 -
 mlir/include/mlir/Linalg/Utils/Intrinsics.h   |   4 +
 mlir/include/mlir/Linalg/Utils/Utils.h        |  12 +-
 mlir/lib/Linalg/CMakeLists.txt                |   9 +-
 mlir/lib/Linalg/IR/LinalgOps.cpp              | 178 ---------
 mlir/lib/Linalg/Transforms/LowerToLoops.cpp   | 348 +++++++++++++++---
 mlir/lib/Linalg/Transforms/Tiling.cpp         |   2 +-
 mlir/lib/Linalg/Utils/Utils.cpp               |  10 -
 mlir/lib/Transforms/LowerVectorTransfers.cpp  |  11 +-
 mlir/test/mlir-tblgen/reference-impl.td       |   4 +-
 14 files changed, 358 insertions(+), 269 deletions(-)

diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp
index 8f97f4317f71..9dd8a4227ae8 100644
--- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp
+++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp
@@ -169,8 +169,8 @@ writeContractionAsLoops(ContractionOp contraction) {
   SmallVector<IndexHandle, 4> parallelIvs(contraction.getNumParallelDims());
   SmallVector<IndexHandle, 4> reductionIvs(contraction.getNumReductionDims());
-  auto pivs = IndexHandle::makeIndexHandlePointers(parallelIvs);
-  auto rivs = IndexHandle::makeIndexHandlePointers(reductionIvs);
+  auto pivs = makeIndexHandlePointers(parallelIvs);
+  auto rivs = makeIndexHandlePointers(reductionIvs);
   assert(loopRanges.size() == pivs.size() + rivs.size());

   // clang-format off
diff --git a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp
index fcb4c20704d1..15e544a773cf 100644
--- a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp
+++ b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp
@@ -150,7 +150,7 @@ writeContractionAsTiledViews(TensorContractionBase<ConcreteOp> &contraction,
   mlir::OpBuilder builder(op->getOperation());
   ScopedContext scope(builder, op->getLoc());
   SmallVector<IndexHandle, 4> ivs(tileSizes.size());
-  auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
+  auto pivs = makeIndexHandlePointers(ivs);

   // clang-format off
   using linalg::common::LoopNestRangeBuilder;
diff --git a/mlir/g3doc/Tutorials/Linalg/DeclarativeBuilders.md b/mlir/g3doc/Tutorials/Linalg/DeclarativeBuilders.md
index 2069e9af276f..54741121c835 100644
--- a/mlir/g3doc/Tutorials/Linalg/DeclarativeBuilders.md
+++ b/mlir/g3doc/Tutorials/Linalg/DeclarativeBuilders.md
@@ -83,8 +83,8 @@ def AddOp : Op<"x.add">,
   Arguments<(ins Tensor:$A, Tensor:$B)>,
   Results<(outs Tensor: $C)> {
   code referenceImplementation = [{
-    auto ivs = IndexHandle::makeIndexHandles(view_A.rank());
-    auto pivs = IndexHandle::makePIndexHandles(ivs);
+    auto ivs = makeIndexHandles(view_A.rank());
+    auto pivs = makeIndexHandlePointers(ivs);
     IndexedValue A(arg_A), B(arg_B), C(arg_C);
     LoopNestBuilder(pivs, view_A.getLbs(), view_A.getUbs(), view_A.getSteps())(
       [&]{
diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h
index 021fec2f444b..6870e029ce89 100644
--- a/mlir/include/mlir/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/EDSC/Intrinsics.h
@@ -61,19 +61,31 @@ struct IndexHandle : public ValueHandle {
     this->v = v.getValue();
     return *this;
   }
-  static SmallVector<IndexHandle, 8> makeIndexHandles(unsigned rank) {
-    return SmallVector<IndexHandle, 8>(rank);
+};
+
+inline SmallVector<IndexHandle, 8> makeIndexHandles(unsigned rank) {
+  return SmallVector<IndexHandle, 8>(rank);
+}
+
+inline SmallVector<ValueHandle *, 8>
+makeIndexHandlePointers(MutableArrayRef<IndexHandle> ivs) {
+  SmallVector<ValueHandle *, 8> pivs;
+  pivs.reserve(ivs.size());
+  for (auto &iv : ivs) {
+    pivs.push_back(&iv);
   }
-  static SmallVector<ValueHandle *, 8>
-  makeIndexHandlePointers(SmallVectorImpl<IndexHandle> &ivs) {
-    SmallVector<ValueHandle *, 8> pivs;
-    pivs.reserve(ivs.size());
-    for (auto &iv : ivs) {
-      pivs.push_back(&iv);
-    }
-    return pivs;
+  return pivs;
+}
+
+/// Returns a vector of the underlying Value* from `ivs`.
+inline SmallVector<Value *, 8> extractValues(ArrayRef<IndexHandle> ivs) {
+  SmallVector<Value *, 8> vals;
+  vals.reserve(ivs.size());
+  for (auto &iv : ivs) {
+    vals.push_back(iv.getValue());
   }
-};
+  return vals;
+}

 /// Provides a set of first class intrinsics.
 /// In the future, most of intrinsics related to Operation that don't contain
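Taken together, the three free functions support the idiom used throughout
this CL: create unbound index handles, pass stable pointers to a loop builder
that binds them, then recover the underlying Value*s for the scalar emitters.
A schematic sketch (assuming a `LoopNestRangeBuilder` and a `ranges` vector as
in the Linalg code further down):

```
// Schematic usage of the now free-standing helpers.
auto ivs = makeIndexHandles(3);            // three unbound loop IVs
auto pivs = makeIndexHandlePointers(ivs);  // ValueHandle* slots to be bound
LoopNestRangeBuilder(pivs, ranges)([&] {
  // Inside the nest, each handle holds its loop's induction variable.
  SmallVector<Value *, 8> ivValues = extractValues(ivs);
  // ... emit the scalar body using ivValues ...
});
```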
diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Linalg/IR/LinalgOps.h
index 41767ad6f902..511f8035d725 100644
--- a/mlir/include/mlir/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Linalg/IR/LinalgOps.h
@@ -436,11 +436,6 @@ private:
   };
 };

-void emitScalarImplementation(llvm::ArrayRef<Value *> parallelIvs,
-                              llvm::ArrayRef<Value *> reductionIvs,
-                              llvm::ArrayRef<Value *> windowIvs,
-                              LinalgOp &linalgOp, OperationFolder &folder);
-
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/include/mlir/Linalg/Utils/Intrinsics.h b/mlir/include/mlir/Linalg/Utils/Intrinsics.h
index c7f3d91282ab..eabec69883e8 100644
--- a/mlir/include/mlir/Linalg/Utils/Intrinsics.h
+++ b/mlir/include/mlir/Linalg/Utils/Intrinsics.h
@@ -27,8 +27,10 @@ class BufferDeallocOp;
 class CopyOp;
 class DimOp;
 class FillOp;
+class LoadOp;
 class RangeOp;
 class SliceOp;
+class StoreOp;
 class ViewOp;
 namespace intrinsics {
 using buffer_alloc = mlir::edsc::intrinsics::ValueBuilder<BufferAllocOp>;
@@ -37,6 +39,8 @@ using buffer_dealloc =
     mlir::edsc::intrinsics::OperationBuilder<BufferDeallocOp>;
 using copy = mlir::edsc::intrinsics::OperationBuilder<CopyOp>;
 using dim = mlir::edsc::intrinsics::ValueBuilder<DimOp>;
 using fill = mlir::edsc::intrinsics::OperationBuilder<FillOp>;
+using linalg_load = mlir::edsc::intrinsics::ValueBuilder<LoadOp>;
+using linalg_store = mlir::edsc::intrinsics::OperationBuilder<StoreOp>;
 using range = mlir::edsc::intrinsics::ValueBuilder<RangeOp>;
 using slice = mlir::edsc::intrinsics::ValueBuilder<SliceOp>;
 using view = mlir::edsc::intrinsics::ValueBuilder<ViewOp>;
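The two new aliases follow the usual EDSC intrinsic pattern: a
`ValueBuilder<OpTy>` creates an op and exposes its result as a `ValueHandle`,
while an `OperationBuilder<OpTy>` creates an op without using a result. Inside
a `ScopedContext` they read like calls; an illustrative sketch for a 1-D view
`v` indexed by `iv` (both assumed in scope):

```
// Illustrative: build linalg.load / linalg.store at the current insertion
// point provided by the enclosing ScopedContext.
ValueHandle x = linalg_load(v, iv); // creates a linalg.load, yields its result
linalg_store(x, v, iv);             // creates a linalg.store, no result
```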
diff --git a/mlir/include/mlir/Linalg/Utils/Utils.h b/mlir/include/mlir/Linalg/Utils/Utils.h
index 1c0335985d74..68d71a8d37c5 100644
--- a/mlir/include/mlir/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Linalg/Utils/Utils.h
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/EDSC/Helpers.h"
 #include "mlir/Linalg/IR/LinalgOps.h"
+#include "mlir/Linalg/Utils/Intrinsics.h"
 #include "mlir/Support/LLVM.h"

 namespace mlir {
@@ -79,7 +80,16 @@ namespace linalg {
 /// Returns the linearized list of all view dimensions in a linalgOp. Applying
 /// the inverse, concatenated loopToOperandRangeMaps to this list allows the
 /// derivation of loop ranges for any linalgOp.
-SmallVector<Value *, 8> getViewSizes(LinalgOp &linalgOp);
+template <typename ConcreteOp>
+SmallVector<Value *, 8> getViewSizes(ConcreteOp linalgOp) {
+  SmallVector<Value *, 8> res;
+  for (auto v : linalgOp.getInputsAndOutputs()) {
+    ViewType t = v->getType().template cast<ViewType>();
+    for (unsigned i = 0; i < t.getRank(); ++i)
+      res.push_back(intrinsics::dim(v, i));
+  }
+  return res;
+}

 /// Returns the values obtained by applying `map` to the list of values.
 /// Performs simplifications and foldings where possible.
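Turning `getViewSizes` into a header-only template lets it accept any concrete
op class rather than only the type-erased `LinalgOp`, which the per-op
patterns below require. For example (hypothetical call site; `matmulOp`,
`invertedMap`, `b`, `loc` and `folder` are assumed in scope), a matmul with
views A: MxK, B: KxN, C: MxN yields the six dimensions in operand order:

```
// Hypothetical: {M, K, K, N, M, N} as linalg.dim values, which the inverted
// index map then turns into loop ranges.
SmallVector<Value *, 8> viewSizes = getViewSizes(matmulOp);
auto loopRanges = emitLoopRanges(b, loc, invertedMap, viewSizes, folder);
```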
diff --git a/mlir/lib/Linalg/CMakeLists.txt b/mlir/lib/Linalg/CMakeLists.txt
index d015940e3c00..b37bdaac4401 100644
--- a/mlir/lib/Linalg/CMakeLists.txt
+++ b/mlir/lib/Linalg/CMakeLists.txt
@@ -14,4 +14,11 @@ add_llvm_library(MLIRLinalg
   DEPENDS
   intrinsics_gen
   )
-add_dependencies(MLIRLinalg MLIRLinalgOpsIncGen MLIRLinalgLibraryOpsIncGen MLIRStandardToLLVM)
+
+add_dependencies(MLIRLinalg
+
+  MLIRAffineOps
+  MLIRLinalgOpsIncGen
+  MLIRLinalgLibraryOpsIncGen
+  MLIRStandardToLLVM
+  )
diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp
index 59bddd302ecc..f56470a69142 100644
--- a/mlir/lib/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Linalg/IR/LinalgOps.cpp
@@ -846,23 +846,6 @@ static SmallVector<Value *, 4> concat(ArrayRef<Value *> a,
   return res;
 }

-static SmallVector<Value *, 4>
-foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
-                    ArrayRef<Value *> vals, OperationFolder &folder) {
-  assert(map.getNumSymbols() == 0);
-  assert(map.getNumInputs() == vals.size());
-  SmallVector<Value *, 4> res;
-  res.reserve(map.getNumResults());
-  auto dims = map.getNumDims();
-  for (auto e : map.getResults()) {
-    auto exprMap = AffineMap::get(dims, 0, e);
-    SmallVector<Value *, 4> operands(vals.begin(), vals.end());
-    canonicalizeMapAndOperands(&exprMap, &operands);
-    res.push_back(affine_apply(folder, exprMap, operands));
-  }
-  return res;
-}
-
 // Note: both functions below would completely disappear with a simple tensor
 // kernel language.
 //
@@ -950,164 +933,3 @@ SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
   }
   llvm_unreachable("Missing loopToOperandRangesMaps for op");
 }
-
-static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
-                                          Optional<AffineMap> permutation,
-                                          OperationFolder &state) {
-  return permutation ? applyMapToValues(ScopedContext::getBuilder(),
-                                        ScopedContext::getLocation(),
-                                        permutation.getValue(), ivs, state)
-                     : SmallVector<Value *, 4>(ivs.begin(), ivs.end());
-}
-
-// Ideally this should all be Tablegen'd but there is no good story for op
-// expansion directly in MLIR for now.
-void mlir::linalg::emitScalarImplementation(
-    llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs,
-    llvm::ArrayRef<Value *> windowIvs, LinalgOp &linalgOp,
-    OperationFolder &folder) {
-  using linalg_load = ValueBuilder<linalg::LoadOp>;
-  using linalg_store = OperationBuilder<linalg::StoreOp>;
-  using IndexedValue = TemplatedIndexedValue<linalg_load, linalg_store>;
-  using edsc::op::operator+;
-  using edsc::op::operator*;
-  using edsc::op::operator==;
-  using edsc::intrinsics::select;
-
-  auto nPar = parallelIvs.size();
-  auto nRed = reductionIvs.size();
-  auto nWin = windowIvs.size();
-  SmallVector<Value *, 8> allIvs;
-  allIvs.reserve(nPar + nRed + nWin);
-  allIvs.assign(parallelIvs.begin(), parallelIvs.end());
-  allIvs.append(reductionIvs.begin(), reductionIvs.end());
-  allIvs.append(windowIvs.begin(), windowIvs.end());
-
-  // Default OpBuilder supports 0-D case (no loops).
-  OpBuilder b(linalgOp.getOperation());
-  auto nLoops = nPar + nRed + nWin;
-  if (nLoops > 0) {
-    auto innermostLoop = loop::getForInductionVarOwner(allIvs.back());
-    // accounts for linalg.terminator in loop.
-    b = innermostLoop.getBodyBuilder();
-  }
-
-  auto loc = linalgOp.getLoc();
-  ScopedContext scope(b, loc);
-  auto *op = linalgOp.getOperation();
-  if (auto copyOp = dyn_cast<CopyOp>(op)) {
-    OperationFolder state;
-    auto inputIvs = permuteIvs(parallelIvs, copyOp.inputPermutation(), state);
-    auto outputIvs = permuteIvs(parallelIvs, copyOp.outputPermutation(), state);
-    SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
-    SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
-    // clang-format off
-    IndexedValue O(copyOp.getOutput(0)), I(copyOp.getInput(0));
-    nLoops > 0 ?
-      O(oivs) = I(iivs) :
-      O() = I();
-    // clang-format on
-    return;
-  }
-  if (auto fillOp = dyn_cast<FillOp>(op)) {
-    SmallVector<IndexHandle, 8> ivs(parallelIvs.begin(), parallelIvs.end());
-    // clang-format off
-    IndexedValue O(fillOp.getOutput(0));
-    nLoops > 0 ?
-      O(ivs) = ValueHandle(fillOp.getValue()) :
-      O() = ValueHandle(fillOp.getValue());
-    // clang-format on
-    return;
-  }
-  if (auto dotOp = dyn_cast<DotOp>(op)) {
-    IndexHandle r_i(reductionIvs[0]);
-    IndexedValue A(dotOp.getInput(0)), B(dotOp.getInput(1)),
-        C(dotOp.getOutput(0));
-    C() = C() + A(r_i) * B(r_i);
-    return;
-  }
-  if (auto matvecOp = dyn_cast<MatvecOp>(op)) {
-    IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]);
-    IndexedValue A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
-        C(matvecOp.getOutput(0));
-    C(i) = C(i) + A(i, r_j) * B(r_j);
-    return;
-  }
-  if (auto matmulOp = dyn_cast<MatmulOp>(op)) {
-    IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]);
-    IndexedValue A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
-        C(matmulOp.getOutput(0));
-    C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
-    return;
-  }
-  if (auto convOp = dyn_cast<ConvOp>(op)) {
-    auto maps = loopToOperandRangesMaps(op);
-    SmallVector<ValueHandle, 8> fIdx(
-        foldedAffineApplies(b, loc, maps[0], allIvs, folder));
-    SmallVector<ValueHandle, 8> imIdx(
-        foldedAffineApplies(b, loc, maps[1], allIvs, folder));
-    SmallVector<ValueHandle, 8> oIdx(
-        foldedAffineApplies(b, loc, maps[2], allIvs, folder));
-    IndexedValue F(convOp.filter()), I(convOp.input()), O(convOp.output());
-    O(oIdx) += F(fIdx) * I(imIdx);
-    return;
-  }
-  if (auto genericOp = dyn_cast<GenericOp>(op)) {
-    using edsc::intrinsics::detail::ValueHandleArray;
-    unsigned nInputs = genericOp.getNumInputs();
-    unsigned nOutputs = genericOp.getNumOutputs();
-    SmallVector<Value *, 4> indexedValues(nInputs + nOutputs);
-    // Emits the MLIR for the scalar part of the generic op by:
-    //   1. Emitting linalg_load and linalg_store ops for each input and output
-    //      view in order. This is achieved by applying the appropriate input or
-    //      output map to the enclosing induction variables.
-    //   2. Emitting a call to `op.fun()` that takes as arguments the scalars
-    //      from point 1. above.
-    //   3. Emitting linalg_store to store the results of 2. to the output
-    //      views.
-    //
-    // An example output may resemble:
-    //
-    // ```
-    //    loop.for %i = %c0 to %0 step %c1 {
-    //      loop.for %j = %c0 to %1 step %c1 {
-    //        loop.for %k = %c0 to %4 step %c1 {
-    //          %11 = linalg.load %arg0[%i, %j] : !linalg.view<?x?xf32>
-    //          %12 = linalg.load %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
-    //          %13 = linalg.load %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
-    //          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
-    //          linalg.store %14#0, %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
-    //          linalg.store %14#1, %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
-    //        }
-    //      }
-    //    }
-    // ```
-
-    // 1.a. Emit linalg_load from input views.
-    for (unsigned i = 0, e = nInputs; i < e; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, genericOp.getInputIndexingMap(i), allIvs, folder));
-      indexedValues[i] = linalg_load(genericOp.getInput(i), indexing);
-    }
-    // 1.b. Emit linalg_load from output views.
-    for (unsigned i = 0, e = nOutputs; i < e; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
-      indexedValues[nInputs + i] =
-          linalg_load(genericOp.getOutput(i), indexing);
-    }
-    // 2. Emit call.
-    auto m = genericOp.getParentOfType<ModuleOp>();
-    auto fun = m.lookupSymbol<FuncOp>(genericOp.fun());
-    Operation *callOp = call(fun, indexedValues);
-    assert(callOp->getNumResults() == genericOp.getNumOutputs());
-    // 3. Emit linalg_store.
-    for (unsigned i = 0, e = nOutputs; i < e; ++i) {
-      ValueHandleArray indexing(foldedAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
-      linalg_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
-    }
-    return;
-  }
-  llvm_unreachable("Missing emitScalarImplementation for op");
-}
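`foldedAffineApplies` (moved verbatim into LowerToLoops.cpp below) splits a
multi-result map into one single-result affine.apply per result, so each
index expression canonicalizes and folds independently. A sketch of the
effect, assuming `b`, `loc`, `folder` and `allIvs = {%i, %j, %k}` in scope:

```
// Sketch: a two-result map produces two independent index values.
MLIRContext *ctx = b.getContext();
auto d0 = getAffineDimExpr(0, ctx);
auto d1 = getAffineDimExpr(1, ctx);
auto d2 = getAffineDimExpr(2, ctx);
auto map = AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {d0 + d2, d1});
auto idxs = foldedAffineApplies(b, loc, map, allIvs, folder);
// idxs[0] is an affine.apply computing %i + %k; the second result is just
// d1, so after canonicalization it should fold to %j with no op emitted.
```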
diff --git a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp
index 2e616c35f1d8..c75ee48aac11 100644
--- a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp
+++ b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp
@@ -15,6 +15,8 @@
 // limitations under the License.
 // =============================================================================

+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/EDSC/Helpers.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
@@ -22,17 +24,50 @@
 #include "mlir/Linalg/IR/LinalgOps.h"
 #include "mlir/Linalg/IR/LinalgTypes.h"
 #include "mlir/Linalg/Passes.h"
+#include "mlir/Linalg/Utils/Intrinsics.h"
 #include "mlir/Linalg/Utils/Utils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/StandardOps/Ops.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/STLExtras.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/FoldUtils.h"

 using namespace mlir;
 using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
 using namespace mlir::linalg;
+using namespace mlir::linalg::intrinsics;
+
+using IndexedLinalgValue = TemplatedIndexedValue<linalg_load, linalg_store>;
+using edsc::op::operator+;
+using edsc::op::operator==;
+
+static SmallVector<Value *, 4>
+foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
+                    ArrayRef<Value *> vals, OperationFolder &folder) {
+  assert(map.getNumSymbols() == 0);
+  assert(map.getNumInputs() == vals.size());
+  SmallVector<Value *, 4> res;
+  res.reserve(map.getNumResults());
+  auto dims = map.getNumDims();
+  for (auto e : map.getResults()) {
+    auto exprMap = AffineMap::get(dims, 0, e);
+    SmallVector<Value *, 4> operands(vals.begin(), vals.end());
+    canonicalizeMapAndOperands(&exprMap, &operands);
+    res.push_back(affine_apply(folder, exprMap, operands));
+  }
+  return res;
+}
+
+static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
+                                          Optional<AffineMap> permutation,
+                                          OperationFolder &state) {
+  return permutation ? applyMapToValues(ScopedContext::getBuilder(),
+                                        ScopedContext::getLocation(),
+                                        permutation.getValue(), ivs, state)
+                     : SmallVector<Value *, 4>(ivs.begin(), ivs.end());
+}

 // Creates a number of ranges equal to the number of results in `map`.
 // The returned ranges correspond to the loop ranges, in the proper order, for
@@ -40,61 +75,272 @@
 static SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
                                               AffineMap map,
                                               ArrayRef<Value *> allViewSizes,
-                                              OperationFolder &state) {
+                                              OperationFolder &folder) {
   // Apply `map` to get view sizes in loop order.
-  auto sizes = applyMapToValues(b, loc, map, allViewSizes, state);
+  auto sizes = applyMapToValues(b, loc, map, allViewSizes, folder);
   // Create a new range with the applied tile sizes.
+  ScopedContext scope(b, loc);
   SmallVector<Value *, 4> res;
   for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
-    res.push_back(b.create<RangeOp>(
-        loc, state.create<ConstantIndexOp>(b, loc, 0), sizes[idx],
-        state.create<ConstantIndexOp>(b, loc, 1)));
+    res.push_back(range(constant_index(folder, 0), sizes[idx],
+                        constant_index(folder, 1)));
   }
   return res;
 }

-static void emitLinalgOpAsLoops(LinalgOp &linalgOp, OperationFolder &state) {
-  OpBuilder b(linalgOp.getOperation());
-  ScopedContext scope(b, linalgOp.getOperation()->getLoc());
-  // The flattened loopToOperandRangesMaps is expected to be an invertible
-  // permutation map (which is asserted in the inverse calculation).
-  auto invertedMap =
-      inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
-  if (!invertedMap) {
-    mlir::linalg::emitScalarImplementation({}, {}, {}, linalgOp, state);
-    return;
+template <typename LinalgOpType> class LinalgScopedEmitter {};
+
+template <> class LinalgScopedEmitter<CopyOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs, CopyOp copyOp,
+                                       OperationFolder &folder) {
+    auto nPar = copyOp.getNumParallelLoops();
+    assert(nPar == allIvs.size());
+    auto inputIvs =
+        permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation(), folder);
+    auto outputIvs =
+        permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation(), folder);
+    SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
+    SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
+    IndexedLinalgValue O(copyOp.getOutput(0)), I(copyOp.getInput(0));
+    // Emit the proper scalar assignment, whether we are dealing with a 0-D or
+    // an n-D loop nest; with or without permutations.
+    // clang-format off
+    nPar > 0 ? O(oivs) = I(iivs) :
+               O() = I();
+    // clang-format on
   }
+};
+
+template <> class LinalgScopedEmitter<FillOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs, FillOp fillOp,
+                                       OperationFolder &folder) {
+    auto nPar = fillOp.getNumParallelLoops();
+    assert(nPar == allIvs.size());
+    auto ivs =
+        SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar);
+    IndexedLinalgValue O(fillOp.getOutput(0));
+    // Emit the proper scalar assignment, whether we are dealing with a 0-D or
+    // an n-D loop nest; with or without permutations.
+    nPar > 0 ? O(ivs) = ValueHandle(fillOp.getValue())
+             : O() = ValueHandle(fillOp.getValue());
+  }
+};
+
+template <> class LinalgScopedEmitter<DotOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs, DotOp dotOp,
+                                       OperationFolder &folder) {
+    assert(allIvs.size() == 1);
+    IndexHandle r_i(allIvs[0]);
+    IndexedLinalgValue A(dotOp.getInput(0)), B(dotOp.getInput(1)),
+        C(dotOp.getOutput(0));
+    // Emit scalar form.
+    C() = C() + A(r_i) * B(r_i);
+  }
+};

-  auto loopRanges = emitLoopRanges(scope.getBuilder(), scope.getLocation(),
-                                   invertedMap, getViewSizes(linalgOp), state);
-
-  SmallVector<IndexHandle, 8> parallelIvs(linalgOp.getNumParallelLoops());
-  SmallVector<IndexHandle, 8> reductionIvs(linalgOp.getNumReductionLoops());
-  SmallVector<IndexHandle, 8> windowIvs(linalgOp.getNumWindowLoops());
-  auto pivs = IndexHandle::makeIndexHandlePointers(parallelIvs);
-  auto rivs = IndexHandle::makeIndexHandlePointers(reductionIvs);
-  auto wivs = IndexHandle::makeIndexHandlePointers(windowIvs);
-  assert(loopRanges.size() == pivs.size() + rivs.size() + wivs.size());
-
-  // clang-format off
-  ArrayRef<Value *> ranges(loopRanges);
-  LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))([&] {
-    LoopNestRangeBuilder(
-        rivs, ranges.drop_back(wivs.size()).take_back(rivs.size()))([&] {
-      LoopNestRangeBuilder(wivs, ranges.take_back(wivs.size()))(
-          [&linalgOp, &parallelIvs, &reductionIvs, &windowIvs, &state] {
-            SmallVector<Value *, 8> parallel(
-                parallelIvs.begin(), parallelIvs.end());
-            SmallVector<Value *, 8> reduction(
-                reductionIvs.begin(), reductionIvs.end());
-            SmallVector<Value *, 8> window(
-                windowIvs.begin(), windowIvs.end());
-            mlir::linalg::emitScalarImplementation(
-                parallel, reduction, window, linalgOp, state);
+template <> class LinalgScopedEmitter<MatvecOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+                                       MatvecOp matvecOp,
+                                       OperationFolder &folder) {
+    assert(allIvs.size() == 2);
+    IndexHandle i(allIvs[0]), r_j(allIvs[1]);
+    IndexedLinalgValue A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
+        C(matvecOp.getOutput(0));
+    // Emit scalar form.
+    C(i) = C(i) + A(i, r_j) * B(r_j);
+  }
+};
+
+template <> class LinalgScopedEmitter<MatmulOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+                                       MatmulOp matmulOp,
+                                       OperationFolder &folder) {
+    assert(allIvs.size() == 3);
+    IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
+    IndexedLinalgValue A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
+        C(matmulOp.getOutput(0));
+    // Emit scalar form.
+    C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
+  }
+};
+
+template <> class LinalgScopedEmitter<ConvOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs, ConvOp convOp,
+                                       OperationFolder &folder) {
+    auto b = ScopedContext::getBuilder();
+    auto loc = ScopedContext::getLocation();
+    auto maps = loopToOperandRangesMaps(convOp);
+    SmallVector<ValueHandle, 8> fIdx(
+        foldedAffineApplies(b, loc, maps[0], allIvs, folder));
+    SmallVector<ValueHandle, 8> imIdx(
+        foldedAffineApplies(b, loc, maps[1], allIvs, folder));
+    SmallVector<ValueHandle, 8> oIdx(
+        foldedAffineApplies(b, loc, maps[2], allIvs, folder));
+    IndexedLinalgValue F(convOp.filter()), I(convOp.input()),
+        O(convOp.output());
+    // Emit scalar form.
+    O(oIdx) += F(fIdx) * I(imIdx);
+  }
+};
+
+// Emits the MLIR for the scalar part of the generic op by:
+//   1. Emitting linalg_load and linalg_store ops for each input and output
+//      view in order. This is achieved by applying the appropriate input or
+//      output map to the enclosing induction variables.
+//   2. Emitting a call to `op.fun()` that takes as arguments the scalars
+//      from point 1. above.
+//   3. Emitting linalg_store to store the results of 2. to the output
+//      views.
+//
+// An example output may resemble:
+//
+// ```
+//    loop.for %i = %c0 to %0 step %c1 {
+//      loop.for %j = %c0 to %1 step %c1 {
+//        loop.for %k = %c0 to %4 step %c1 {
+//          %11 = linalg.load %arg0[%i, %j] : !linalg.view<?x?xf32>
+//          %12 = linalg.load %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
+//          %13 = linalg.load %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
+//          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
+//          linalg.store %14#0, %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
+//          linalg.store %14#1, %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
+//        }
+//      }
+//    }
+// ```
+template <> class LinalgScopedEmitter<GenericOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+                                       GenericOp genericOp,
+                                       OperationFolder &folder) {
+    auto b = ScopedContext::getBuilder();
+    auto loc = ScopedContext::getLocation();
+    using edsc::intrinsics::detail::ValueHandleArray;
+    unsigned nInputs = genericOp.getNumInputs();
+    unsigned nOutputs = genericOp.getNumOutputs();
+    SmallVector<Value *, 4> indexedValues(nInputs + nOutputs);
+
+    // 1.a. Emit linalg_load from input views.
+    for (unsigned i = 0, e = nInputs; i < e; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, genericOp.getInputIndexingMap(i), allIvs, folder));
+      indexedValues[i] = linalg_load(genericOp.getInput(i), indexing);
+    }
+
+    // 1.b. Emit linalg_load from output views.
+    for (unsigned i = 0, e = nOutputs; i < e; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+      indexedValues[nInputs + i] =
+          linalg_load(genericOp.getOutput(i), indexing);
+    }
+
+    // 2. Emit call.
+    auto m = genericOp.getParentOfType<ModuleOp>();
+    auto fun = m.lookupSymbol<FuncOp>(genericOp.fun());
+    Operation *callOp = call(fun, indexedValues);
+    assert(callOp->getNumResults() == genericOp.getNumOutputs());
+
+    // 3. Emit linalg_store.
+    for (unsigned i = 0, e = nOutputs; i < e; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+      linalg_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
+    }
+  }
+};
+
+template <typename ConcreteOp>
+class LinalgRewritePattern : public RewritePattern {
+public:
+  explicit LinalgRewritePattern(MLIRContext *context)
+      : RewritePattern(ConcreteOp::getOperationName(), /*benefit=*/1,
+                       context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    OpBuilder b(op);
+    ScopedContext scope(b, op->getLoc());
+
+    // The flattened loopToOperandRangesMaps is expected to be an invertible
+    // permutation map (which is asserted in the inverse calculation).
+    auto linalgOp = cast<ConcreteOp>(op);
+    auto invertedMap =
+        inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
+    if (!invertedMap) {
+      LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation({}, linalgOp,
+                                                                folder);
+      rewriter.replaceOp(op, {});
+      return matchSuccess();
+    }
+
+    auto nPar = linalgOp.getNumParallelLoops();
+    auto nRed = linalgOp.getNumReductionLoops();
+    auto nWin = linalgOp.getNumWindowLoops();
+    SmallVector<IndexHandle, 8> allIvs(nPar + nRed + nWin);
+    SmallVector<ValueHandle *, 8> allPIvs = makeIndexHandlePointers(allIvs);
+    auto pivs = MutableArrayRef<ValueHandle *>(allPIvs).take_front(nPar);
+    auto rivs = MutableArrayRef<ValueHandle *>(allPIvs)
+                    .take_front(nPar + nRed)
+                    .take_back(nRed);
+    auto wivs = MutableArrayRef<ValueHandle *>(allPIvs).take_back(nWin);
+
+    auto loopRanges =
+        emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
+                       getViewSizes(linalgOp), folder);
+    assert(loopRanges.size() == pivs.size() + rivs.size() + wivs.size());
+
+    // clang-format off
+    ArrayRef<Value *> ranges(loopRanges);
+    LoopNestRangeBuilder(pivs, ranges.take_front(nPar))([&] {
+      LoopNestRangeBuilder(rivs, ranges.drop_back(nWin).take_back(nRed))([&] {
+        LoopNestRangeBuilder(wivs, ranges.take_back(wivs.size()))(
+          [&linalgOp, &allIvs, this] {
+            auto allIvValues = extractValues(allIvs);
+            LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation(
+                allIvValues, linalgOp, folder);
           });
       });
     });
-  });
-  // clang-format on
+    // clang-format on
+    rewriter.replaceOp(op, {});
+    return matchSuccess();
+  }
+
+  mutable OperationFolder folder;
+};
+
+// Helper classes for type list expansion.
+template <typename... LinalgOps> class ConversionList;
+
+template <> class ConversionList<> {
+public:
+  static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
+};
+
+template <typename ConcreteOp, typename... LinalgOps>
+class ConversionList<ConcreteOp, LinalgOps...> {
+public:
+  static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
+    patterns.insert<LinalgRewritePattern<ConcreteOp>>(ctx);
+    ConversionList<LinalgOps...>::build(patterns, ctx);
+  }
+};
+
+/// Populates the given list with patterns that convert from Linalg to loops.
+static void
+populateLinalgToLoopRewritePatterns(OwningRewritePatternList &patterns,
+                                    MLIRContext *ctx) {
+  ConversionList<
+#define GET_OP_LIST
+#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
+      >::build(patterns, ctx);
 }

 namespace {
@@ -104,11 +350,17 @@ struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> {
 } // namespace

 void LowerLinalgToLoopsPass::runOnFunction() {
-  OperationFolder state;
-  getFunction().walk([&state](LinalgOp linalgOp) {
-    emitLinalgOpAsLoops(linalgOp, state);
-    linalgOp.getOperation()->erase();
-  });
+  OwningRewritePatternList patterns;
+  populateLinalgToLoopRewritePatterns(patterns, &getContext());
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<AffineOpsDialect>();
+  target.addLegalDialect<loop::LoopOpsDialect>();
+  target.addLegalDialect<StandardOpsDialect>();
+  if (failed(
+          applyPartialConversion(getFunction(), target, std::move(patterns)))) {
+    signalPassFailure();
+  }
 }

 FunctionPassBase *mlir::linalg::createLowerLinalgToLoopsPass() {
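With the emitter/pattern split, supporting a new op becomes additive: one
`LinalgScopedEmitter` specialization suffices, and the tablegen'd op list in
`populateLinalgToLoopRewritePatterns` picks it up. For a hypothetical
pointwise `PointwiseAddOp` (not part of this patch), the specialization would
look like:

```
// Hypothetical emitter: C(ivs) = A(ivs) + B(ivs) over nPar parallel loops.
// PointwiseAddOp is illustrative only; it does not exist in this CL.
template <> class LinalgScopedEmitter<PointwiseAddOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
                                       PointwiseAddOp addOp,
                                       OperationFolder &folder) {
    assert(addOp.getNumParallelLoops() == allIvs.size());
    SmallVector<IndexHandle, 8> ivs(allIvs.begin(), allIvs.end());
    IndexedLinalgValue A(addOp.getInput(0)), B(addOp.getInput(1)),
        C(addOp.getOutput(0));
    // Emit scalar form.
    C(ivs) = A(ivs) + B(ivs);
  }
};
```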
diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp
index 25ffdebc61ab..8090a587d426 100644
--- a/mlir/lib/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Linalg/Transforms/Tiling.cpp
@@ -381,7 +381,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,
   // 3. Create the tiled loops.
   LinalgOp res = op;
   SmallVector<IndexHandle, 4> ivs(loopRanges.size());
-  auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
+  auto pivs = makeIndexHandlePointers(ivs);
   LoopNestRangeBuilder(pivs, loopRanges)([&] {
     auto b = ScopedContext::getBuilder();
     auto loc = ScopedContext::getLocation();
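As in the lowering pattern above, tiling passes `ValueHandle*` rather than
values because `LoopNestRangeBuilder` binds each induction variable in place
while it creates the loops; the handles are placeholders until the nest body
runs. The contract, roughly:

```
// Before the builder runs, the handles in `ivs` are unbound; inside the
// body callback they carry the newly created loop induction variables.
SmallVector<IndexHandle, 8> ivs(loopRanges.size());
auto pivs = makeIndexHandlePointers(ivs); // pointers into `ivs`
LoopNestRangeBuilder(pivs, loopRanges)([&] {
  // ivs[0] ... ivs[n-1] are now valid and can index views.
});
```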
diff --git a/mlir/lib/Linalg/Utils/Utils.cpp b/mlir/lib/Linalg/Utils/Utils.cpp
index 850aefe0eae4..d31fe0d30063 100644
--- a/mlir/lib/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Linalg/Utils/Utils.cpp
@@ -106,16 +106,6 @@ ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()(
   return ValueHandle::null();
 }

-SmallVector<Value *, 8> mlir::linalg::getViewSizes(LinalgOp &linalgOp) {
-  SmallVector<Value *, 8> res;
-  for (auto v : linalgOp.getInputsAndOutputs()) {
-    ViewType t = v->getType().cast<ViewType>();
-    for (unsigned i = 0; i < t.getRank(); ++i)
-      res.push_back(linalg::intrinsics::dim(v, i));
-  }
-  return res;
-}
-
 static Value *emitOrFoldComposedAffineApply(OpBuilder &b, Location loc,
                                             AffineMap map,
                                             ArrayRef<Value *> operandsRef,
diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp
index ef67488023f9..cda62d9ddc07 100644
--- a/mlir/lib/Transforms/LowerVectorTransfers.cpp
+++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp
@@ -273,10 +273,9 @@ VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
   IndexedValue remote(transfer.getMemRef());
   MemRefView view(transfer.getMemRef());
   VectorView vectorView(transfer.getVector());
-  SmallVector<IndexHandle, 8> ivs =
-      IndexHandle::makeIndexHandles(vectorView.rank());
+  SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
   SmallVector<ValueHandle *, 8> pivs =
-      IndexHandle::makeIndexHandlePointers(ivs);
+      makeIndexHandlePointers(MutableArrayRef<IndexHandle>(ivs));
   coalesceCopy(transfer, &pivs, &vectorView);

   auto lbs = vectorView.getLbs();
@@ -335,10 +334,8 @@ VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
   MemRefView view(transfer.getMemRef());
   ValueHandle vectorValue(transfer.getVector());
   VectorView vectorView(transfer.getVector());
-  SmallVector<IndexHandle, 8> ivs =
-      IndexHandle::makeIndexHandles(vectorView.rank());
-  SmallVector<ValueHandle *, 8> pivs =
-      IndexHandle::makeIndexHandlePointers(ivs);
+  SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
+  SmallVector<ValueHandle *, 8> pivs = makeIndexHandlePointers(ivs);
   coalesceCopy(transfer, &pivs, &vectorView);

   auto lbs = vectorView.getLbs();
diff --git a/mlir/test/mlir-tblgen/reference-impl.td b/mlir/test/mlir-tblgen/reference-impl.td
index 69b1787c89bc..07baa8791e39 100644
--- a/mlir/test/mlir-tblgen/reference-impl.td
+++ b/mlir/test/mlir-tblgen/reference-impl.td
@@ -16,8 +16,8 @@ def X_AddOp : X_Op<"add">,
   Results<(outs AnyTensor: $C)> {
   // TODO: extract referenceImplementation to Op.
   code referenceImplementation = [{
-    auto ivs = IndexHandle::makeIndexHandles(view_A.rank());
-    auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
+    auto ivs = makeIndexHandles(view_A.rank());
+    auto pivs = makeIndexHandlePointers(ivs);
     IndexedValue A(arg_A), B(arg_B), C(arg_C);
     LoopNestBuilder(pivs, view_A.getLbs(), view_A.getUbs(), view_A.getSteps())({
       C(ivs) = A(ivs) + B(ivs)
-- 
2.34.1