From c0f41e5bb3d615f35c7708992304bb989929cdb1 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 23 May 2019 07:00:18 -0700 Subject: [PATCH] Fix Linalg lowering to loops This CL makes lowering to loops always be a: ``` %D = linalg.dim %view, constant : !linalg.view<...> affine.for %ix = %c0 to %D { ... } ``` This form composes correctly with tiling and is also the proper way to emit loops from views that across function boundaries. The previous version that would extract the range_min/max/step was composing incorrectly with tiling (i.e. would shift by range_min both in the loop bounds and in the slice) and would not work across function boundaries. The relevant tests are updated and a new test `dot_view`---which lowers to loops from views passed as function parameters---is added. When additional context is available, the linalg.dim operations should be folded away but this is left for a future CL. -- PiperOrigin-RevId: 249634712 --- mlir/include/mlir/Linalg/IR/LinalgOps.h | 4 +++ mlir/include/mlir/Linalg/Utils/Utils.h | 10 +++++- mlir/lib/Linalg/Transforms/LowerToLoops.cpp | 44 +++++++++++++----------- mlir/lib/Linalg/Utils/Utils.cpp | 28 +++++++++------- mlir/test/Linalg/loops.mlir | 52 ++++++++++++++++++++--------- 5 files changed, 90 insertions(+), 48 deletions(-) diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Linalg/IR/LinalgOps.h index 7a106cf..26a5cf3 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Linalg/IR/LinalgOps.h @@ -336,6 +336,10 @@ public: ArrayRef operands) { return impl->create(builder, loc, operands); } + Operation::operand_range getInputsAndOutputs() { + auto range = this->getOperation()->getOperands(); + return {range.begin(), range.begin() + getNumInputsAndOutputs()}; + } private: struct Concept { diff --git a/mlir/include/mlir/Linalg/Utils/Utils.h b/mlir/include/mlir/Linalg/Utils/Utils.h index ebdfe97..eea92be 100644 --- a/mlir/include/mlir/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Linalg/Utils/Utils.h @@ -89,8 +89,16 @@ Value *createOrReturnView(FuncBuilder *b, Location loc, enum class RangePart { Min = 0, Max, Step }; Value *extractRangePart(Value *range, RangePart part); +/// Returns the values obtained by applying `map` to the list of values. +/// Performs simplifications and foldings where possible. +SmallVector applyMapToValues(FuncBuilder *b, Location loc, + AffineMap map, + ArrayRef values, + FunctionConstants &state); + /// Returns the values obtained by applying `map` to the list of range parts -/// extracted from `ranges`. +/// extracted from `ranges`. Performs simplifications and foldings where +/// possible. SmallVector applyMapToRangePart(FuncBuilder *b, Location loc, AffineMap map, ArrayRef ranges, diff --git a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp index 4691d1f..b248578 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp @@ -24,49 +24,55 @@ #include "mlir/Linalg/Passes.h" #include "mlir/Linalg/Utils/Utils.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/STLExtras.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" - using namespace mlir; using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; using namespace mlir::linalg; -using namespace llvm; // Creates a number of ranges equal to the number of results in `map`. // The returned ranges correspond to the loop ranges, in the proper order, for // which new loops will be created. -static SmallVector makeLoopRanges(FuncBuilder *b, Location loc, +static SmallVector emitLoopRanges(FuncBuilder *b, Location loc, AffineMap map, - ArrayRef allOpRanges, + ArrayRef allViewSizes, FunctionConstants &state) { - // Apply `map` to get mins/maxes/steps in loop order. - auto mins = - applyMapToRangePart(b, loc, map, allOpRanges, RangePart::Min, state); - auto maxes = - applyMapToRangePart(b, loc, map, allOpRanges, RangePart::Max, state); - auto steps = - applyMapToRangePart(b, loc, map, allOpRanges, RangePart::Step, state); - + // Apply `map` to get view sizes in loop order. + auto sizes = applyMapToValues(b, loc, map, allViewSizes, state); // Create a new range with the applied tile sizes. SmallVector res; - for (unsigned idx = 0, e = steps.size(); idx < e; ++idx) - res.push_back(b->create(loc, mins[idx], maxes[idx], steps[idx])); + for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) { + res.push_back(b->create(loc, state.getOrCreateIndex(0), sizes[idx], + state.getOrCreateIndex(1))); + } + return res; +} + +// Returns the linearized list of all view dimensions in a linalgOp. Appliying +// the inverse, concatenated loopToOperandRangeMaps to this list allows the +// derivation of loop ranges for any linalgOp. +static SmallVector getViewSizes(LinalgOp &linalgOp) { + SmallVector res; + using dim = ValueBuilder; + for (auto v : linalgOp.getInputsAndOutputs()) { + ViewType t = v->getType().cast(); + for (unsigned i = 0; i < t.getRank(); ++i) + res.push_back(dim(v, i)); + } return res; } static void emitLinalgOpAsLoops(LinalgOp &linalgOp, FunctionConstants &state) { FuncBuilder b(linalgOp.getOperation()); ScopedContext scope(b, linalgOp.getOperation()->getLoc()); - auto loopRanges = makeLoopRanges( + auto loopRanges = emitLoopRanges( scope.getBuilder(), scope.getLocation(), // The flattened loopToOperandRangesMaps is expected to be an invertible // permutation map (which is asserted in the inverse calculation). inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp))), - getRanges(linalgOp.getOperation()), state); + getViewSizes(linalgOp), state); SmallVector parallelIvs(linalgOp.getNumParallelLoops()); SmallVector reductionIvs(linalgOp.getNumReductionLoops()); diff --git a/mlir/lib/Linalg/Utils/Utils.cpp b/mlir/lib/Linalg/Utils/Utils.cpp index 00591c6..9d4e2c9 100644 --- a/mlir/lib/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Linalg/Utils/Utils.cpp @@ -167,16 +167,10 @@ static Value *emitOrFoldComposedAffineApply(FuncBuilder *b, Location loc, return b->create(loc, map, operands); } -SmallVector mlir::applyMapToRangePart(FuncBuilder *b, Location loc, - AffineMap map, - ArrayRef ranges, - RangePart part, - FunctionConstants &state) { - SmallVector rangeParts(ranges.size()); - - llvm::transform(ranges, rangeParts.begin(), - [&](Value *range) { return extractRangePart(range, part); }); - +SmallVector mlir::applyMapToValues(FuncBuilder *b, Location loc, + AffineMap map, + ArrayRef values, + FunctionConstants &state) { SmallVector res; res.reserve(map.getNumResults()); unsigned numDims = map.getNumDims(); @@ -185,12 +179,22 @@ SmallVector mlir::applyMapToRangePart(FuncBuilder *b, Location loc, // folding occurs eagerly. Otherwise, an affine.apply operation is emitted. for (auto expr : map.getResults()) { AffineMap map = AffineMap::get(numDims, 0, expr, {}); - res.push_back( - emitOrFoldComposedAffineApply(b, loc, map, rangeParts, state)); + res.push_back(emitOrFoldComposedAffineApply(b, loc, map, values, state)); } return res; } +SmallVector mlir::applyMapToRangePart(FuncBuilder *b, Location loc, + AffineMap map, + ArrayRef ranges, + RangePart part, + FunctionConstants &state) { + SmallVector rangeParts(ranges.size()); + llvm::transform(ranges, rangeParts.begin(), + [&](Value *range) { return extractRangePart(range, part); }); + return applyMapToValues(b, loc, map, rangeParts, state); +} + Value *FunctionConstants::getOrCreateIndex(int64_t v) { auto it = map.find(v); if (it != map.end()) diff --git a/mlir/test/Linalg/loops.mlir b/mlir/test/Linalg/loops.mlir index cdedde5..5e9246f 100644 --- a/mlir/test/Linalg/loops.mlir +++ b/mlir/test/Linalg/loops.mlir @@ -18,14 +18,17 @@ func @matmul(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: inde // CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view // CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view // CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view -// CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%arg1) { -// CHECK: affine.for %i1 = #[[ID]](%c0) to #[[ID]](%arg2) { -// CHECK: affine.for %i2 = #[[ID]](%c0) to #[[ID]](%arg3) { +// CHECK: %[[M:.*]] = linalg.dim %[[A]], 0 : !linalg.view +// CHECK: %[[K:.*]] = linalg.dim %[[A]], 1 : !linalg.view +// CHECK: %[[N:.*]] = linalg.dim %[[B]], 1 : !linalg.view +// CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%[[M]]) { +// CHECK: affine.for %i1 = #[[ID]](%c0) to #[[ID]](%[[N]]) { +// CHECK: affine.for %i2 = #[[ID]](%c0) to #[[ID]](%[[K]]) { // CHECK-DAG: %[[a:.*]] = linalg.load %[[A]][%i0, %i2] : !linalg.view // CHECK-DAG: %[[b:.*]] = linalg.load %[[B]][%i2, %i1] : !linalg.view -// CHECK: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32 -// CHECK: %[[c:.*]] = linalg.load %[[C]][%i0, %i1] : !linalg.view -// CHECK: %[[res:.*]] = addf %[[c]], %[[inc]] : f32 +// CHECK-DAG: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32 +// CHECK-DAG: %[[c:.*]] = linalg.load %[[C]][%i0, %i1] : !linalg.view +// CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32 // CHECK: linalg.store %[[res]], %[[C]][%i0, %i1] : !linalg.view func @matvec(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { @@ -43,13 +46,15 @@ func @matvec(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: inde // CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view // CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view // CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view -// CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%arg1) { -// CHECK: affine.for %i1 = #[[ID]](%c0) to #[[ID]](%arg2) { +// CHECK: %[[M:.*]] = linalg.dim %[[A]], 0 : !linalg.view +// CHECK: %[[K:.*]] = linalg.dim %[[A]], 1 : !linalg.view +// CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%[[M]]) { +// CHECK: affine.for %i1 = #[[ID]](%c0) to #[[ID]](%[[K]]) { // CHECK-DAG: %[[a:.*]] = linalg.load %[[A]][%i0, %i1] : !linalg.view // CHECK-DAG: %[[b:.*]] = linalg.load %[[B]][%i1] : !linalg.view -// CHECK: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32 -// CHECK: %[[c:.*]] = linalg.load %[[C]][%i0] : !linalg.view -// CHECK: %[[res:.*]] = addf %[[c]], %[[inc]] : f32 +// CHECK-DAG: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32 +// CHECK-DAG: %[[c:.*]] = linalg.load %[[C]][%i0] : !linalg.view +// CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32 // CHECK: linalg.store %[[res]], %[[C]][%i0] : !linalg.view func @dot(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { @@ -66,10 +71,25 @@ func @dot(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) // CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view // CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view // CHECK: %[[C:.*]] = linalg.view %arg0[] : !linalg.view -// CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%arg1) { +// CHECK: %[[K:.*]] = linalg.dim %[[A]], 0 : !linalg.view +// CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%[[K]]) { // CHECK-DAG: %[[a:.*]] = linalg.load %[[A]][%i0] : !linalg.view // CHECK-DAG: %[[b:.*]] = linalg.load %[[B]][%i0] : !linalg.view -// CHECK: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32 -// CHECK: %[[c:.*]] = linalg.load %[[C]][] : !linalg.view -// CHECK: %[[res:.*]] = addf %[[c]], %[[inc]] : f32 -// CHECK: linalg.store %[[res]], %[[C]][] : !linalg.view \ No newline at end of file +// CHECK-DAG: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32 +// CHECK-DAG: %[[c:.*]] = linalg.load %[[C]][] : !linalg.view +// CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32 +// CHECK: linalg.store %[[res]], %[[C]][] : !linalg.view + +func @dot_view(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: !linalg.view) { + linalg.dot(%arg0, %arg1, %arg2) : !linalg.view, !linalg.view, !linalg.view + return +} +// CHECK-LABEL: func @dot_view(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: !linalg.view) { +// CHECK: %[[K:.*]] = linalg.dim %arg0, 0 : !linalg.view +// CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%[[K]]) { +// CHECK-DAG: %[[a:.*]] = linalg.load %arg0[%i0] : !linalg.view +// CHECK-DAG: %[[b:.*]] = linalg.load %arg1[%i0] : !linalg.view +// CHECK-DAG: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32 +// CHECK-DAG: %[[c:.*]] = linalg.load %arg2[] : !linalg.view +// CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32 +// CHECK: linalg.store %[[res]], %arg2[] : !linalg.view -- 2.7.4