/// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
/// step (%arg4*tileSize[0],
/// %arg5*tileSize[1])
-/// scf.parallel (%j0, %j1) = (0, 0) to (min(tileSize[0], %arg2-%j0)
-/// min(tileSize[1], %arg3-%j1))
+/// scf.parallel (%j0, %j1) = (0, 0) to (min(tileSize[0], %arg2-%i0),
+/// min(tileSize[1], %arg3-%i1))
/// step (%arg4, %arg5)
+///
+/// where the uses of %i0 and %i1 in the loop body are replaced by
+/// %i0 + %j0 and %i1 + %j1, respectively.
+///
/// The old loop is replaced with the new one.
void mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
OpBuilder b(op);
// Steal the body of the old parallel loop and erase it.
innerLoop.region().takeBody(op.region());
+
+ // Insert computation for new index vectors and replace uses: for every
+ // dimension, emit `inner IV + outer IV` at the top of the inner loop body and
+ // redirect the uses of the inner IV to that sum.
+ b.setInsertionPointToStart(innerLoop.getBody());
+ for (auto ivs :
+ llvm::zip(innerLoop.getInductionVars(), outerLoop.getInductionVars())) {
+ Value innerIndex = std::get<0>(ivs);
+ Value outerIndex = std::get<1>(ivs);
+ AddIOp newIndex = b.create<AddIOp>(op.getLoc(), innerIndex, outerIndex);
+ // The addition itself reads the inner IV, so it has to be exempt from the
+ // replacement; otherwise it would be rewritten to consume its own result.
+ innerIndex.replaceAllUsesExcept(
+ newIndex, SmallPtrSet<Operation *, 1>{newIndex.getOperation()});
+ }
+
op.erase();
}
// CHECK: [[VAL_17:%.*]] = affine.min #map0([[VAL_11]], [[VAL_2]], [[VAL_15]])
// CHECK: [[VAL_18:%.*]] = affine.min #map0([[VAL_12]], [[VAL_3]], [[VAL_16]])
// CHECK: scf.parallel ([[VAL_19:%.*]], [[VAL_20:%.*]]) = ([[VAL_10]], [[VAL_10]]) to ([[VAL_17]], [[VAL_18]]) step ([[VAL_4]], [[VAL_5]]) {
-// CHECK: [[VAL_21:%.*]] = load [[VAL_7]]{{\[}}[[VAL_19]], [[VAL_20]]] : memref<?x?xf32>
-// CHECK: [[VAL_22:%.*]] = load [[VAL_8]]{{\[}}[[VAL_19]], [[VAL_20]]] : memref<?x?xf32>
-// CHECK: [[VAL_23:%.*]] = addf [[VAL_21]], [[VAL_22]] : f32
-// CHECK: store [[VAL_23]], [[VAL_9]]{{\[}}[[VAL_19]], [[VAL_20]]] : memref<?x?xf32>
+// CHECK: [[VAL_21:%.*]] = addi [[VAL_19]], [[VAL_15]] : index
+// CHECK: [[VAL_22:%.*]] = addi [[VAL_20]], [[VAL_16]] : index
+// CHECK: [[VAL_23:%.*]] = load [[VAL_7]]{{\[}}[[VAL_21]], [[VAL_22]]] : memref<?x?xf32>
+// CHECK: [[VAL_24:%.*]] = load [[VAL_8]]{{\[}}[[VAL_21]], [[VAL_22]]] : memref<?x?xf32>
+// CHECK: [[VAL_25:%.*]] = addf [[VAL_23]], [[VAL_24]] : f32
+// CHECK: store [[VAL_25]], [[VAL_9]]{{\[}}[[VAL_21]], [[VAL_22]]] : memref<?x?xf32>
// CHECK: }
// CHECK: }
// CHECK: return