/// Bundles what `tileAndFuseProducerOfSlice` hands back: the producer that
/// was fused together with the value and operations that tiling it yielded.
struct SCFFuseProducerOfSliceResult {
OpResult origProducer; // Original untiled producer.
Value tiledAndFusedProducer; // Tiled and fused producer value.
+ SmallVector<Operation *> tiledOps; // Ops taken from the tile-and-fuse result's `tiledOps`.
};
std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
}
}
return scf::SCFFuseProducerOfSliceResult{fusableProducer,
- tileAndFuseResult->tiledValues[0]};
+ tileAndFuseResult->tiledValues[0],
+ tileAndFuseResult->tiledOps};
}
/// Reconstruct the fused producer from within the tiled-and-fused code.
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
MutableArrayRef<scf::ForOp> loops) {
- auto [fusableProducer, fusedProducerValue] = fusedProducerInfo;
+ auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
+ fusedProducerInfo;
SmallVector<Value> initValues;
FailureOr<Value> initValue = tensor::getOrCreateDestination(
rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
resultOffsets, resultSizes, loops);
}
- if (auto dstStyleProducer =
- fusedProducerValue.getDefiningOp<DestinationStyleOpInterface>()) {
+ for (auto tileAndFusedOp : tileAndFusedOps) {
+ auto dstStyleProducer =
+ dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
+ if (!dstStyleProducer)
+ continue;
Value dstValue =
dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
->get();