[Linalg] Add tiling for IndexedGenericOp with a region.
authorAlexander Belyaev <pifon@google.com>
Wed, 11 Dec 2019 10:56:06 +0000 (02:56 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 11 Dec 2019 10:56:40 +0000 (02:56 -0800)
PiperOrigin-RevId: 284949355

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/test/Dialect/Linalg/tile_indexed_generic.mlir [new file with mode: 0644]

index 38ebf5d..31291b6 100644 (file)
@@ -58,14 +58,18 @@ static bool isZero(Value *v) {
          cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0;
 }
 
+using LoopIndexToRangeIndexMap = DenseMap<int, int>;
+
 // Creates a number of ranges equal to the number of non-zero in `tileSizes`.
 // One for each loop of the LinalgOp that is tiled. The `tileSizes` argument has
 // one entry per surrounding loop. It uses zero as the convention that a
 // particular loop is not tiled. This convention simplifies implementations by
 // avoiding affine map manipulations.
 // The returned ranges correspond to the loop ranges, in the proper order, that
-// are tiled and for which new loops will be created.
-static SmallVector<SubViewOp::Range, 4>
+// are tiled and for which new loops will be created. Also the function returns
+// a map from loop indices of the LinalgOp to the corresponding non-empty range
+// indices of newly created loops.
+static std::tuple<SmallVector<SubViewOp::Range, 4>, LoopIndexToRangeIndexMap>
 makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
                     ArrayRef<Value *> allViewSizes,
                     ArrayRef<Value *> allTileSizes, OperationFolder *folder) {
@@ -75,11 +79,15 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
   SmallVector<Value *, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());
 
   // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
-  for (int idx = tileSizes.size() - 1; idx >= 0; --idx) {
-    if (isZero(tileSizes[idx])) {
-      viewSizes.erase(viewSizes.begin() + idx);
-      tileSizes.erase(tileSizes.begin() + idx);
+  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
+  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
+    if (isZero(tileSizes[idx - zerosCount])) {
+      viewSizes.erase(viewSizes.begin() + idx - zerosCount);
+      tileSizes.erase(tileSizes.begin() + idx - zerosCount);
+      ++zerosCount;
+      continue;
     }
+    loopIndexToRangeIndex[idx] = idx - zerosCount;
   }
 
   // Create a new range with the applied tile sizes.
@@ -88,10 +96,11 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
     res.push_back(SubViewOp::Range{constant_index(folder, 0), viewSizes[idx],
                                    tileSizes[idx]});
   }
-  return res;
+  return std::make_tuple(res, loopIndexToRangeIndex);
 }
 
 namespace {
+
 // Helper visitor to determine whether an AffineExpr is tiled.
 // This is achieved by traversing every AffineDimExpr with position `pos` and
 // checking whether the corresponding `tileSizes[pos]` is non-zero.
@@ -117,8 +126,99 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
   bool isTiled;
   ArrayRef<Value *> tileSizes;
 };
+
 } // namespace
 
+// IndexedGenericOp explicitly uses induction variables in the loop body. The
+// values of the indices that are used in the loop body for any given access of
+// input/output memref before `subview` op was applied should be invariant with
+// respect to tiling.
+//
+// Therefore, if the operation is tiled, we have to transform the indices
+// accordingly, i.e. offset them by the values of the corresponding induction
+// variables that are captured implicitly in the body of the op.
+//
+// Example. `linalg.indexed_generic` before tiling:
+//
+// #id_2d = (i, j) -> (i, j)
+// #pointwise_2d_trait = {
+//   indexing_maps = [#id_2d, #id_2d],
+//   iterator_types = ["parallel", "parallel"],
+//   n_views = [1, 1]
+// }
+// linalg.indexed_generic #pointwise_2d_trait %operand, %result {
+//   ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
+//     <some operations that use %i, %j>
+// }: memref<50x100xf32>, memref<50x100xf32>
+//
+// After tiling pass with tiles sizes 10 and 25:
+//
+// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
+//
+// %c1 = constant 1 : index
+// %c0 = constant 0 : index
+// %c25 = constant 25 : index
+// %c10 = constant 10 : index
+// operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
+// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
+// loop.for %k = %c0 to operand_dim_0 step %c10 {
+//   loop.for %l = %c0 to operand_dim_1 step %c25 {
+//     %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
+//       : memref<50x100xf32> to memref<?x?xf32, #strided>
+//     %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
+//       : memref<50x100xf32> to memref<?x?xf32, #strided>
+//     linalg.indexed_generic pointwise_2d_trait %4, %5 {
+//     ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
+//       // Indices `k` and `l` are implicitly captured in the body.
+//       %transformed_i = addi %i, %k : index // index `i` is offset by %k
+//       %transformed_j = addi %j, %l : index // index `j` is offset by %l
+//       // Every use of %i, %j is replaced with %transformed_i, %transformed_j
+//       <some operations that use %transformed_i, %transformed_j>
+//     }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
+//   }
+// }
+//
+// TODO(pifon, ntv): Investigate whether mixing implicit and explicit indices
+// does not lead to losing information.
+void transformIndexedGenericOpIndices(
+    OpBuilder &b, LinalgOp op, ArrayRef<ValueHandle *> pivs,
+    const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
+  auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
+  if (!indexedGenericOp)
+    return;
+
+  // `linalg.indexed_generic` comes in two flavours. One has a region with a
+  // single block that defines the loop body. The other has a `fun` attribute
+  // that refers to an existing function symbol. The `fun` function call will be
+  // inserted in the loop body in that case.
+  //
+  // TODO(pifon): Add support for `linalg.indexed_generic` with `fun` attribute.
+  auto &region = indexedGenericOp.region();
+  if (region.empty()) {
+    indexedGenericOp.emitError("op expected a region");
+    return;
+  }
+  auto &block = region.getBlocks().front();
+
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPointToStart(&block);
+  for (unsigned i = 0; i < indexedGenericOp.getNumLoops(); ++i) {
+    auto rangeIndex = loopIndexToRangeIndex.find(i);
+    if (rangeIndex == loopIndexToRangeIndex.end())
+      continue;
+    Value *oldIndex = block.getArgument(i);
+    // Offset the index argument `i` by the value of the corresponding induction
+    // variable and replace all uses of the previous value.
+    Value *newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
+                                       pivs[rangeIndex->second]->getValue());
+    for (auto &use : oldIndex->getUses()) {
+      if (use.getOwner() == newIndex->getDefiningOp())
+        continue;
+      use.set(newIndex);
+    }
+  }
+}
+
 static bool isTiled(AffineExpr expr, ArrayRef<Value *> tileSizes) {
   if (!expr)
     return false;
@@ -244,7 +344,10 @@ llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
   auto viewSizesToLoopsMap =
       inversePermutation(concatAffineMaps(loopToOperandRangesMaps(op)));
   assert(viewSizesToLoopsMap && "expected invertible map");
-  auto loopRanges =
+
+  SmallVector<SubViewOp::Range, 4> loopRanges;
+  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
+  std::tie(loopRanges, loopIndexToRangeIndex) =
       makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap,
                           viewSizes, tileSizes, folder);
   if (!permutation.empty())
@@ -274,7 +377,10 @@ llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
     res = op.clone(b, loc, views);
   });
 
-  // 4. Gather the newly created loops and return them with the new op.
+  // 4. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
+  transformIndexedGenericOpIndices(b, res, pivs, loopIndexToRangeIndex);
+
+  // 5. Gather the newly created loops and return them with the new op.
   SmallVector<ForOp, 8> loops;
   loops.reserve(ivs.size());
   for (auto iv : ivs)
diff --git a/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir b/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir
new file mode 100644 (file)
index 0000000..c17a463
--- /dev/null
@@ -0,0 +1,102 @@
+// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=10,25 | FileCheck %s -check-prefix=TILE-10n25
+// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=25,0 | FileCheck %s -check-prefix=TILE-25n0
+// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=0,25 | FileCheck %s -check-prefix=TILE-0n25
+
+#id_1d = (i) -> (i)
+#pointwise_1d_trait = {
+  indexing_maps = [#id_1d, #id_1d],
+  iterator_types = ["parallel"],
+  n_views = [1, 1]
+}
+func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>) {
+  linalg.indexed_generic #pointwise_1d_trait %operand, %result {
+    ^bb0(%i: index, %operand_in: f32, %result_in: f32):
+      %i_int = index_cast %i: index to i32
+      %i_float = sitofp %i_int : i32 to f32
+      %out = addf %operand_in, %i_float : f32
+      linalg.yield %out : f32
+  }: memref<50xf32>, memref<50xf32>
+  return
+}
+// TILE-10n25-LABEL: func @indexed_generic_vector
+// TILE-10n25: %[[C10:.*]] = constant 10 : index
+// TILE-10n25: loop.for %[[J:.*]] = {{.*}} step %[[C10]]
+// TILE-10n25:   linalg.indexed_generic
+// TILE-10n25:   ^bb0(%[[I:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32)
+// TILE-10n25:     %[[NEW_I:.*]] = addi %[[I]], %[[J]] : index
+// TILE-10n25:     %[[NEW_I_INT:.*]] = index_cast %[[NEW_I]] : index to i32
+// TILE-10n25:     %[[NEW_I_FLOAT:.*]] = sitofp %[[NEW_I_INT]] : i32 to f32
+// TILE-10n25:     %[[OUT:.*]] = addf %[[IN]], %[[NEW_I_FLOAT]] : f32
+
+// TILE-25n0-LABEL: func @indexed_generic_vector
+// TILE-25n0: %[[C25:.*]] = constant 25 : index
+// TILE-25n0: loop.for %[[J:.*]] = {{.*}} step %[[C25]]
+// TILE-25n0:   linalg.indexed_generic
+// TILE-25n0:   ^bb0(%[[I:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32)
+// TILE-25n0:     %[[NEW_I:.*]] = addi %[[I]], %[[J]] : index
+// TILE-25n0:     %[[NEW_I_INT:.*]] = index_cast %[[NEW_I]] : index to i32
+// TILE-25n0:     %[[NEW_I_FLOAT:.*]] = sitofp %[[NEW_I_INT]] : i32 to f32
+// TILE-25n0:     %[[OUT:.*]] = addf %[[IN]], %[[NEW_I_FLOAT]] : f32
+
+// TILE-0n25-LABEL: func @indexed_generic_vector
+// TILE-0n25-NOT: loop.for %[[J:.*]] = {{.*}} step %[[C25]]
+// TILE-0n25: linalg.indexed_generic
+
+#combined_indices_trait = {
+  indexing_maps = [
+    (i, j) -> (j, i + j),
+    (i, j) -> (i, j)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  n_views = [1, 1]
+}
+func @indexed_generic_matrix(%operand: memref<50x100xf32>, %result: memref<50x100xf32>) {
+  linalg.indexed_generic #combined_indices_trait %operand, %result {
+    ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
+      %i_int = index_cast %i: index to i32
+      %i_float = sitofp %i_int : i32 to f32
+      %j_int = index_cast %j: index to i32
+      %j_float = sitofp %j_int : i32 to f32
+      %out = addf %i_float, %j_float : f32
+      linalg.yield %out : f32
+  }: memref<50x100xf32>, memref<50x100xf32>
+  return
+}
+// TILE-10n25-LABEL: func @indexed_generic_matrix
+// TILE-10n25: %[[C25:.*]] = constant 25 : index
+// TILE-10n25: %[[C10:.*]] = constant 10 : index
+// TILE-10n25: loop.for %[[K:.*]] = {{.*}} step %[[C10]]
+// TILE-10n25:   loop.for %[[L:.*]] = {{.*}} step %[[C25]]
+// TILE-10n25:     linalg.indexed_generic
+// TILE-10n25:     ^bb0(%[[I:.*]]: index, %[[J:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// TILE-10n25:       %[[NEW_I:.*]] = addi %[[I]], %[[K]] : index
+// TILE-10n25:       %[[NEW_J:.*]] = addi %[[J]], %[[L]] : index
+// TILE-10n25:       %[[NEW_INT_I:.*]] = index_cast %[[NEW_I]] : index to i32
+// TILE-10n25:       %[[NEW_FLOAT_I:.*]] = sitofp %[[NEW_INT_I]] : i32 to f32
+// TILE-10n25:       %[[NEW_INT_J:.*]] = index_cast %[[NEW_J]] : index to i32
+// TILE-10n25:       %[[NEW_FLOAT_J:.*]] = sitofp %[[NEW_INT_J]] : i32 to f32
+// TILE-10n25:       %[[OUT:.*]] = addf %[[NEW_FLOAT_I]], %[[NEW_FLOAT_J]] : f32
+
+// TILE-25n0-LABEL: func @indexed_generic_matrix
+// TILE-25n0: %[[C25:.*]] = constant 25 : index
+// TILE-25n0: loop.for %[[L:.*]] = {{.*}} step %[[C25]]
+// TILE-25n0:   linalg.indexed_generic
+// TILE-25n0:   ^bb0(%[[I:.*]]: index, %[[J:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// TILE-25n0:     %[[NEW_I:.*]] = addi %[[I]], %[[L]] : index
+// TILE-25n0:     %[[NEW_INT_I:.*]] = index_cast %[[NEW_I]] : index to i32
+// TILE-25n0:     %[[NEW_FLOAT_I:.*]] = sitofp %[[NEW_INT_I]] : i32 to f32
+// TILE-25n0:     %[[INT_J:.*]] = index_cast %[[J]] : index to i32
+// TILE-25n0:     %[[FLOAT_J:.*]] = sitofp %[[INT_J]] : i32 to f32
+// TILE-25n0:     %[[OUT:.*]] = addf %[[NEW_FLOAT_I]], %[[FLOAT_J]] : f32
+
+// TILE-0n25-LABEL: func @indexed_generic_matrix
+// TILE-0n25: %[[C25:.*]] = constant 25 : index
+// TILE-0n25: loop.for %[[L:.*]] = {{.*}} step %[[C25]]
+// TILE-0n25:   linalg.indexed_generic
+// TILE-0n25:   ^bb0(%[[I:.*]]: index, %[[J:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// TILE-0n25:     %[[NEW_J:.*]] = addi %[[J]], %[[L]] : index
+// TILE-0n25:     %[[INT_I:.*]] = index_cast %[[I]] : index to i32
+// TILE-0n25:     %[[FLOAT_I:.*]] = sitofp %[[INT_I]] : i32 to f32
+// TILE-0n25:     %[[NEW_INT_J:.*]] = index_cast %[[NEW_J]] : index to i32
+// TILE-0n25:     %[[NEW_FLOAT_J:.*]] = sitofp %[[NEW_INT_J]] : i32 to f32
+// TILE-0n25:     %[[OUT:.*]] = addf %[[FLOAT_I]], %[[NEW_FLOAT_J]] : f32