[MLIR] Fix parallel loop tiling.
authorStephan Herhut <herhut@google.com>
Wed, 17 Jun 2020 14:22:07 +0000 (16:22 +0200)
committerStephan Herhut <herhut@google.com>
Wed, 17 Jun 2020 21:30:13 +0000 (23:30 +0200)
Summary:
Parallel loop tiling did not properly compute the updated loop
indices when tiling, which lead to wrong results.

Differential Revision: https://reviews.llvm.org/D82013

mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
mlir/test/Dialect/SCF/parallel-loop-tiling.mlir

index 8e84566..4046913 100644 (file)
@@ -30,9 +30,13 @@ using namespace mlir::scf;
 ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
 ///                                             step (%arg4*tileSize[0],
 ///                                                   %arg5*tileSize[1])
-///     scf.parallel (%j0, %j1) = (0, 0) to (min(tileSize[0], %arg2-%j0)
-///                                           min(tileSize[1], %arg3-%j1))
+///     scf.parallel (%j0, %j1) = (0, 0) to (min(tileSize[0], %arg2-%i0)
+///                                           min(tileSize[1], %arg3-%i1))
 ///                                        step (%arg4, %arg5)
+///
+/// where the uses of %i0 and %i1 in the loop body are replaced by
+/// %i0 + j0 and %i1 + %j1.
+//
 /// The old loop is replaced with the new one.
 void mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
   OpBuilder b(op);
@@ -85,6 +89,18 @@ void mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
 
   // Steal the body of the old parallel loop and erase it.
   innerLoop.region().takeBody(op.region());
+
+  // Insert computation for new index vectors and replace uses.
+  b.setInsertionPointToStart(innerLoop.getBody());
+  for (auto ivs :
+       llvm::zip(innerLoop.getInductionVars(), outerLoop.getInductionVars())) {
+    Value inner_index = std::get<0>(ivs);
+    AddIOp newIndex =
+        b.create<AddIOp>(op.getLoc(), std::get<0>(ivs), std::get<1>(ivs));
+    inner_index.replaceAllUsesExcept(
+        newIndex, SmallPtrSet<Operation *, 1>{newIndex.getOperation()});
+  }
+
   op.erase();
 }
 
index 1491243..f124162 100644 (file)
@@ -25,10 +25,12 @@ func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
 // CHECK:             [[VAL_17:%.*]] = affine.min #map0([[VAL_11]], [[VAL_2]], [[VAL_15]])
 // CHECK:             [[VAL_18:%.*]] = affine.min #map0([[VAL_12]], [[VAL_3]], [[VAL_16]])
 // CHECK:             scf.parallel ([[VAL_19:%.*]], [[VAL_20:%.*]]) = ([[VAL_10]], [[VAL_10]]) to ([[VAL_17]], [[VAL_18]]) step ([[VAL_4]], [[VAL_5]]) {
-// CHECK:               [[VAL_21:%.*]] = load [[VAL_7]]{{\[}}[[VAL_19]], [[VAL_20]]] : memref<?x?xf32>
-// CHECK:               [[VAL_22:%.*]] = load [[VAL_8]]{{\[}}[[VAL_19]], [[VAL_20]]] : memref<?x?xf32>
-// CHECK:               [[VAL_23:%.*]] = addf [[VAL_21]], [[VAL_22]] : f32
-// CHECK:               store [[VAL_23]], [[VAL_9]]{{\[}}[[VAL_19]], [[VAL_20]]] : memref<?x?xf32>
+// CHECK:               [[VAL_21:%.*]] = addi [[VAL_19]], [[VAL_15]] : index
+// CHECK:               [[VAL_22:%.*]] = addi [[VAL_20]], [[VAL_16]] : index
+// CHECK:               [[VAL_23:%.*]] = load [[VAL_7]]{{\[}}[[VAL_21]], [[VAL_22]]] : memref<?x?xf32>
+// CHECK:               [[VAL_24:%.*]] = load [[VAL_8]]{{\[}}[[VAL_21]], [[VAL_22]]] : memref<?x?xf32>
+// CHECK:               [[VAL_25:%.*]] = addf [[VAL_23]], [[VAL_24]] : f32
+// CHECK:               store [[VAL_25]], [[VAL_9]]{{\[}}[[VAL_21]], [[VAL_22]]] : memref<?x?xf32>
 // CHECK:             }
 // CHECK:           }
 // CHECK:           return