Fix Linalg lowering to loops
authorNicolas Vasilache <ntv@google.com>
Thu, 23 May 2019 14:00:18 +0000 (07:00 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 02:57:54 +0000 (19:57 -0700)
    This CL makes lowering to loops always be a:
    ```
    %D = linalg.dim %view, constant : !linalg.view<...>
    affine.for %ix = %c0 to %D {
      ...
    }
    ```

    This form composes correctly with tiling and is also the proper way to emit loops from views that across function boundaries.
    The previous version that would extract the range_min/max/step was composing incorrectly with tiling (i.e. would shift by range_min both in the loop bounds and in the slice) and would not work across function boundaries.

    The relevant tests are updated and a new test `dot_view`---which lowers to loops from views passed as function parameters---is added.

    When additional context is available, the linalg.dim operations should be folded away but this is left for a future CL.

--

PiperOrigin-RevId: 249634712

mlir/include/mlir/Linalg/IR/LinalgOps.h
mlir/include/mlir/Linalg/Utils/Utils.h
mlir/lib/Linalg/Transforms/LowerToLoops.cpp
mlir/lib/Linalg/Utils/Utils.cpp
mlir/test/Linalg/loops.mlir

index 7a106cf..26a5cf3 100644 (file)
@@ -336,6 +336,10 @@ public:
                     ArrayRef<Value *> operands) {
     return impl->create(builder, loc, operands);
   }
+  Operation::operand_range getInputsAndOutputs() {
+    auto range = this->getOperation()->getOperands();
+    return {range.begin(), range.begin() + getNumInputsAndOutputs()};
+  }
 
 private:
   struct Concept {
index ebdfe97..eea92be 100644 (file)
@@ -89,8 +89,16 @@ Value *createOrReturnView(FuncBuilder *b, Location loc,
 enum class RangePart { Min = 0, Max, Step };
 Value *extractRangePart(Value *range, RangePart part);
 
+/// Returns the values obtained by applying `map` to the list of values.
+/// Performs simplifications and foldings where possible.
+SmallVector<Value *, 4> applyMapToValues(FuncBuilder *b, Location loc,
+                                         AffineMap map,
+                                         ArrayRef<Value *> values,
+                                         FunctionConstants &state);
+
 /// Returns the values obtained by applying `map` to the list of range parts
-/// extracted from `ranges`.
+/// extracted from `ranges`. Performs simplifications and foldings where
+/// possible.
 SmallVector<Value *, 4> applyMapToRangePart(FuncBuilder *b, Location loc,
                                             AffineMap map,
                                             ArrayRef<Value *> ranges,
index 4691d1f..b248578 100644 (file)
 #include "mlir/Linalg/Passes.h"
 #include "mlir/Linalg/Utils/Utils.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Support/STLExtras.h"
 
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/raw_ostream.h"
-
 using namespace mlir;
 using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
 using namespace mlir::linalg;
-using namespace llvm;
 
 // Creates a number of ranges equal to the number of results in `map`.
 // The returned ranges correspond to the loop ranges, in the proper order, for
 // which new loops will be created.
-static SmallVector<Value *, 4> makeLoopRanges(FuncBuilder *b, Location loc,
+static SmallVector<Value *, 4> emitLoopRanges(FuncBuilder *b, Location loc,
                                               AffineMap map,
-                                              ArrayRef<Value *> allOpRanges,
+                                              ArrayRef<Value *> allViewSizes,
                                               FunctionConstants &state) {
-  // Apply `map` to get mins/maxes/steps in loop order.
-  auto mins =
-      applyMapToRangePart(b, loc, map, allOpRanges, RangePart::Min, state);
-  auto maxes =
-      applyMapToRangePart(b, loc, map, allOpRanges, RangePart::Max, state);
-  auto steps =
-      applyMapToRangePart(b, loc, map, allOpRanges, RangePart::Step, state);
-
+  // Apply `map` to get view sizes in loop order.
+  auto sizes = applyMapToValues(b, loc, map, allViewSizes, state);
   // Create a new range with the applied tile sizes.
   SmallVector<Value *, 4> res;
-  for (unsigned idx = 0, e = steps.size(); idx < e; ++idx)
-    res.push_back(b->create<RangeOp>(loc, mins[idx], maxes[idx], steps[idx]));
+  for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
+    res.push_back(b->create<RangeOp>(loc, state.getOrCreateIndex(0), sizes[idx],
+                                     state.getOrCreateIndex(1)));
+  }
+  return res;
+}
+
+// Returns the linearized list of all view dimensions in a linalgOp. Appliying
+// the inverse, concatenated loopToOperandRangeMaps to this list allows the
+// derivation of loop ranges for any linalgOp.
+static SmallVector<Value *, 8> getViewSizes(LinalgOp &linalgOp) {
+  SmallVector<Value *, 8> res;
+  using dim = ValueBuilder<linalg::DimOp>;
+  for (auto v : linalgOp.getInputsAndOutputs()) {
+    ViewType t = v->getType().cast<ViewType>();
+    for (unsigned i = 0; i < t.getRank(); ++i)
+      res.push_back(dim(v, i));
+  }
   return res;
 }
 
 static void emitLinalgOpAsLoops(LinalgOp &linalgOp, FunctionConstants &state) {
   FuncBuilder b(linalgOp.getOperation());
   ScopedContext scope(b, linalgOp.getOperation()->getLoc());
-  auto loopRanges = makeLoopRanges(
+  auto loopRanges = emitLoopRanges(
       scope.getBuilder(), scope.getLocation(),
       // The flattened loopToOperandRangesMaps is expected to be an invertible
       // permutation map (which is asserted in the inverse calculation).
       inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp))),
-      getRanges(linalgOp.getOperation()), state);
+      getViewSizes(linalgOp), state);
 
   SmallVector<IndexHandle, 4> parallelIvs(linalgOp.getNumParallelLoops());
   SmallVector<IndexHandle, 4> reductionIvs(linalgOp.getNumReductionLoops());
index 00591c6..9d4e2c9 100644 (file)
@@ -167,16 +167,10 @@ static Value *emitOrFoldComposedAffineApply(FuncBuilder *b, Location loc,
   return b->create<AffineApplyOp>(loc, map, operands);
 }
 
-SmallVector<Value *, 4> mlir::applyMapToRangePart(FuncBuilder *b, Location loc,
-                                                  AffineMap map,
-                                                  ArrayRef<Value *> ranges,
-                                                  RangePart part,
-                                                  FunctionConstants &state) {
-  SmallVector<Value *, 4> rangeParts(ranges.size());
-
-  llvm::transform(ranges, rangeParts.begin(),
-                  [&](Value *range) { return extractRangePart(range, part); });
-
+SmallVector<Value *, 4> mlir::applyMapToValues(FuncBuilder *b, Location loc,
+                                               AffineMap map,
+                                               ArrayRef<Value *> values,
+                                               FunctionConstants &state) {
   SmallVector<Value *, 4> res;
   res.reserve(map.getNumResults());
   unsigned numDims = map.getNumDims();
@@ -185,12 +179,22 @@ SmallVector<Value *, 4> mlir::applyMapToRangePart(FuncBuilder *b, Location loc,
   // folding occurs eagerly. Otherwise, an affine.apply operation is emitted.
   for (auto expr : map.getResults()) {
     AffineMap map = AffineMap::get(numDims, 0, expr, {});
-    res.push_back(
-        emitOrFoldComposedAffineApply(b, loc, map, rangeParts, state));
+    res.push_back(emitOrFoldComposedAffineApply(b, loc, map, values, state));
   }
   return res;
 }
 
+SmallVector<Value *, 4> mlir::applyMapToRangePart(FuncBuilder *b, Location loc,
+                                                  AffineMap map,
+                                                  ArrayRef<Value *> ranges,
+                                                  RangePart part,
+                                                  FunctionConstants &state) {
+  SmallVector<Value *, 4> rangeParts(ranges.size());
+  llvm::transform(ranges, rangeParts.begin(),
+                  [&](Value *range) { return extractRangePart(range, part); });
+  return applyMapToValues(b, loc, map, rangeParts, state);
+}
+
 Value *FunctionConstants::getOrCreateIndex(int64_t v) {
   auto it = map.find(v);
   if (it != map.end())
index cdedde5..5e9246f 100644 (file)
@@ -18,14 +18,17 @@ func @matmul(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: inde
 //       CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
 //       CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
 //       CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
-//       CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%arg1) {
-//       CHECK:   affine.for %i1 = #[[ID]](%c0) to #[[ID]](%arg2) {
-//       CHECK:     affine.for %i2 = #[[ID]](%c0) to #[[ID]](%arg3) {
+//       CHECK: %[[M:.*]] = linalg.dim %[[A]], 0 : !linalg.view<?x?xf32>
+//       CHECK: %[[K:.*]] = linalg.dim %[[A]], 1 : !linalg.view<?x?xf32>
+//       CHECK: %[[N:.*]] = linalg.dim %[[B]], 1 : !linalg.view<?x?xf32>
+//       CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%[[M]]) {
+//       CHECK:   affine.for %i1 = #[[ID]](%c0) to #[[ID]](%[[N]]) {
+//       CHECK:     affine.for %i2 = #[[ID]](%c0) to #[[ID]](%[[K]]) {
 //   CHECK-DAG:       %[[a:.*]] = linalg.load %[[A]][%i0, %i2] : !linalg.view<?x?xf32>
 //   CHECK-DAG:       %[[b:.*]] = linalg.load %[[B]][%i2, %i1] : !linalg.view<?x?xf32>
-//       CHECK:       %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
-//       CHECK:       %[[c:.*]] = linalg.load %[[C]][%i0, %i1] : !linalg.view<?x?xf32>
-//       CHECK:       %[[res:.*]] = addf %[[c]], %[[inc]] : f32
+//   CHECK-DAG:       %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
+//   CHECK-DAG:       %[[c:.*]] = linalg.load %[[C]][%i0, %i1] : !linalg.view<?x?xf32>
+//   CHECK-DAG:       %[[res:.*]] = addf %[[c]], %[[inc]] : f32
 //       CHECK:       linalg.store %[[res]], %[[C]][%i0, %i1] : !linalg.view<?x?xf32>
 
 func @matvec(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
@@ -43,13 +46,15 @@ func @matvec(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: inde
 //       CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
 //       CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
 //       CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
-//       CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%arg1) {
-//       CHECK:   affine.for %i1 = #[[ID]](%c0) to #[[ID]](%arg2) {
+//       CHECK: %[[M:.*]] = linalg.dim %[[A]], 0 : !linalg.view<?x?xf32>
+//       CHECK: %[[K:.*]] = linalg.dim %[[A]], 1 : !linalg.view<?x?xf32>
+//       CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%[[M]]) {
+//       CHECK:   affine.for %i1 = #[[ID]](%c0) to #[[ID]](%[[K]]) {
 //   CHECK-DAG:     %[[a:.*]] = linalg.load %[[A]][%i0, %i1] : !linalg.view<?x?xf32>
 //   CHECK-DAG:     %[[b:.*]] = linalg.load %[[B]][%i1] : !linalg.view<?xf32>
-//       CHECK:     %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
-//       CHECK:     %[[c:.*]] = linalg.load %[[C]][%i0] : !linalg.view<?xf32>
-//       CHECK:     %[[res:.*]] = addf %[[c]], %[[inc]] : f32
+//   CHECK-DAG:     %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
+//   CHECK-DAG:     %[[c:.*]] = linalg.load %[[C]][%i0] : !linalg.view<?xf32>
+//   CHECK-DAG:     %[[res:.*]] = addf %[[c]], %[[inc]] : f32
 //       CHECK:     linalg.store %[[res]], %[[C]][%i0] : !linalg.view<?xf32>
 
 func @dot(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
@@ -66,10 +71,25 @@ func @dot(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index)
 //       CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
 //       CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
 //       CHECK: %[[C:.*]] = linalg.view %arg0[] : !linalg.view<f32>
-//       CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%arg1) {
+//       CHECK: %[[K:.*]] = linalg.dim %[[A]], 0 : !linalg.view<?xf32>
+//       CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%[[K]]) {
 //   CHECK-DAG:   %[[a:.*]] = linalg.load %[[A]][%i0] : !linalg.view<?xf32>
 //   CHECK-DAG:   %[[b:.*]] = linalg.load %[[B]][%i0] : !linalg.view<?xf32>
-//       CHECK:   %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
-//       CHECK:   %[[c:.*]] = linalg.load %[[C]][] : !linalg.view<f32>
-//       CHECK:   %[[res:.*]] = addf %[[c]], %[[inc]] : f32
-//       CHECK:   linalg.store %[[res]], %[[C]][] : !linalg.view<f32>
\ No newline at end of file
+//   CHECK-DAG:   %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
+//   CHECK-DAG:   %[[c:.*]] = linalg.load %[[C]][] : !linalg.view<f32>
+//   CHECK-DAG:   %[[res:.*]] = addf %[[c]], %[[inc]] : f32
+//       CHECK:   linalg.store %[[res]], %[[C]][] : !linalg.view<f32>
+
+func @dot_view(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<f32>) {
+  linalg.dot(%arg0, %arg1, %arg2) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+  return
+}
+// CHECK-LABEL: func @dot_view(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<f32>) {
+//       CHECK: %[[K:.*]] = linalg.dim %arg0, 0 : !linalg.view<?xf32>
+//       CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%[[K]]) {
+//   CHECK-DAG:   %[[a:.*]] = linalg.load %arg0[%i0] : !linalg.view<?xf32>
+//   CHECK-DAG:   %[[b:.*]] = linalg.load %arg1[%i0] : !linalg.view<?xf32>
+//   CHECK-DAG:   %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
+//   CHECK-DAG:   %[[c:.*]] = linalg.load %arg2[] : !linalg.view<f32>
+//   CHECK-DAG:   %[[res:.*]] = addf %[[c]], %[[inc]] : f32
+//       CHECK:   linalg.store %[[res]], %arg2[] : !linalg.view<f32>