Do a load and a store in every test to verify that we process each element of the iteration space exactly once.
Reviewed By: cota
Differential Revision: https://reviews.llvm.org/D115152
Value blockFirstIndex = b.create<arith::MulIOp>(blockIndex, blockSize);
// The last one-dimensional index in the block defined by the `blockIndex`:
+ // blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1
Value blockEnd0 = b.create<arith::AddIOp>(blockFirstIndex, blockSize);
Value blockEnd1 = b.create<arith::MinSIOp>(blockEnd0, tripCount);
Value blockLastIndex = b.create<arith::SubIOp>(blockEnd1, c1);
// iteration coordinate using parallel operation bounds and step:
//
// computeBlockInductionVars[loopIdx] =
- // lowerBound[loopIdx] + blockCoord[loopIdx] * step[loopDdx]
+ // lowerBound[loopIdx] + blockCoord[loopIdx] * step[loopIdx]
SmallVector<Value> computeBlockInductionVars(op.getNumLoops());
// We need to know if we are in the first or last iteration of the
// Keep building loop nest.
if (loopIdx < op.getNumLoops() - 1) {
- // Select nested loop lower/upper bounds depending on out position in
+ // Select nested loop lower/upper bounds depending on our position in
// the multi-dimensional iteration space.
auto lb = nb.create<SelectOp>(isBlockFirstCoord[loopIdx],
blockFirstCoord[loopIdx + 1], c0);
%A = memref.alloc() : memref<9xf32>
%U = memref.cast %A : memref<9xf32> to memref<*xf32>
+ // Initialize the memref with zeros because every test does a load followed
+ // by a store to verify that we process each element of the iteration space once.
+ scf.parallel (%i) = (%lb) to (%ub) step (%c1) {
+ memref.store %c0, %A[%i] : memref<9xf32>
+ }
+
// 1. %i = (0) to (9) step (1)
scf.parallel (%i) = (%lb) to (%ub) step (%c1) {
%0 = arith.index_cast %i : index to i32
%1 = arith.sitofp %0 : i32 to f32
- memref.store %1, %A[%i] : memref<9xf32>
+ %2 = memref.load %A[%i] : memref<9xf32>
+ %3 = arith.addf %1, %2 : f32
+ memref.store %3, %A[%i] : memref<9xf32>
}
// CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8]
call @print_memref_f32(%U): (memref<*xf32>) -> ()
scf.parallel (%i) = (%lb) to (%ub) step (%c2) {
%0 = arith.index_cast %i : index to i32
%1 = arith.sitofp %0 : i32 to f32
- memref.store %1, %A[%i] : memref<9xf32>
+ %2 = memref.load %A[%i] : memref<9xf32>
+ %3 = arith.addf %1, %2 : f32
+ memref.store %3, %A[%i] : memref<9xf32>
}
// CHECK: [0, 0, 2, 0, 4, 0, 6, 0, 8]
call @print_memref_f32(%U): (memref<*xf32>) -> ()
%1 = arith.sitofp %0 : i32 to f32
%2 = arith.constant 20 : index
%3 = arith.addi %i, %2 : index
- memref.store %1, %A[%3] : memref<9xf32>
+ %4 = memref.load %A[%3] : memref<9xf32>
+ %5 = arith.addf %1, %4 : f32
+ memref.store %5, %A[%3] : memref<9xf32>
}
// CHECK: [-20, 0, 0, -17, 0, 0, -14, 0, 0]
call @print_memref_f32(%U): (memref<*xf32>) -> ()
%A = memref.alloc() : memref<8x8xf32>
%U = memref.cast %A : memref<8x8xf32> to memref<*xf32>
+ // Initialize the memref with zeros because every test does a load followed
+ // by a store to verify that we process each element of the iteration space once.
+ scf.parallel (%i, %j) = (%lb, %lb) to (%ub, %ub) step (%c1, %c1) {
+ memref.store %c0, %A[%i, %j] : memref<8x8xf32>
+ }
+
// 1. (%i, %j) = (0, 0) to (8, 8) step (1, 1)
scf.parallel (%i, %j) = (%lb, %lb) to (%ub, %ub) step (%c1, %c1) {
%0 = arith.muli %i, %c8 : index
%1 = arith.addi %j, %0 : index
%2 = arith.index_cast %1 : index to i32
%3 = arith.sitofp %2 : i32 to f32
- memref.store %3, %A[%i, %j] : memref<8x8xf32>
+ %4 = memref.load %A[%i, %j] : memref<8x8xf32>
+ %5 = arith.addf %3, %4 : f32
+ memref.store %5, %A[%i, %j] : memref<8x8xf32>
}
// CHECK: [0, 1, 2, 3, 4, 5, 6, 7]
%1 = arith.addi %j, %0 : index
%2 = arith.index_cast %1 : index to i32
%3 = arith.sitofp %2 : i32 to f32
- memref.store %3, %A[%i, %j] : memref<8x8xf32>
+ %4 = memref.load %A[%i, %j] : memref<8x8xf32>
+ %5 = arith.addf %3, %4 : f32
+ memref.store %5, %A[%i, %j] : memref<8x8xf32>
}
// CHECK: [0, 1, 2, 3, 4, 5, 6, 7]
%1 = arith.addi %j, %0 : index
%2 = arith.index_cast %1 : index to i32
%3 = arith.sitofp %2 : i32 to f32
- memref.store %3, %A[%i, %j] : memref<8x8xf32>
+ %4 = memref.load %A[%i, %j] : memref<8x8xf32>
+ %5 = arith.addf %3, %4 : f32
+ memref.store %5, %A[%i, %j] : memref<8x8xf32>
}
// CHECK: [0, 0, 2, 0, 4, 0, 6, 0]