if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
return rewriter.notifyMatchFailure(
warpOp, "Reduction vector dimension must match warp size.");
- // Only f32 and i32 element types are supported.
+ // Only f32, i32, f16, i8 element types are supported.
if (!reductionOp.getType().isF32() &&
- !reductionOp.getType().isSignlessInteger(32))
+ !reductionOp.getType().isSignlessInteger(32) &&
+ !reductionOp.getType().isF16() && !reductionOp.getType().isInteger(8))
return rewriter.notifyMatchFailure(
- warpOp,
- "Reduction distribution currently only supports 32bits types.");
+ warpOp, "Reduction distribution currently only supports 32bits, f16, "
+ "and i8 types.");
int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
// Return vector that will be reduced from the WarpExecuteOnLane0Op.
rewriter, warpOp, yieldValues, retTypes, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
+ // Obtain data to reduce for a single lane.
Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
- // First reduce on a single thread.
- Value perLaneReduction = rewriter.create<vector::ReductionOp>(
- reductionOp.getLoc(), reductionOp.getKind(), laneValVec);
- // Then distribute across threads.
+ // Distribute and reduce across threads.
Value fullReduce =
- distributedReductionFn(reductionOp.getLoc(), rewriter, perLaneReduction,
+ distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
reductionOp.getKind(), newWarpOp.getWarpSize());
if (reductionOp.getAcc()) {
fullReduce = vector::makeArithReduction(
static Value warpReduction(Location loc, OpBuilder &builder, Value input,
CombiningKind kind, uint32_t size) {
- Value laneVal = input;
+ // First reduce on a single thread to get per lane reduction value.
+ Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
// Parallel reduction using butterfly shuffles.
for (uint64_t i = 1; i < size; i <<= 1) {
Value shuffled = builder