/// Emit an scf.yield at the current insertion point. If `hasRetVal` is true,
/// the yield carries `value` (which must then be non-null); otherwise an empty
/// scf.yield is emitted.
static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
Value value) {
if (hasRetVal) {
+ assert(value && "Expected non-empty value");
b.create<scf::YieldOp>(loc, value);
} else {
b.create<scf::YieldOp>(loc);
// NOTE(review): the next line references `newXferOp`/`kPassLabel`, neither of
// which is in scope in this function — it looks spliced in from a different
// hunk of the patch; confirm against the full diff.
newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}
+/// Return true if this transfer op operates on a source tensor.
+///
+/// `OpTy` is expected to be a vector transfer op (TransferReadOp or
+/// TransferWriteOp). As a sanity check, a TransferWriteOp on a tensor must
+/// produce at least one result (the updated tensor value).
+template <typename OpTy>
+static bool isTensorOp(OpTy xferOp) {
+ if (xferOp.getShapedType().template isa<RankedTensorType>()) {
+ if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
+ // TransferWriteOps on tensors have a result.
+ assert(xferOp->getNumResults() > 0);
+ }
+ return true;
+ }
+ return false;
+}
+
namespace lowering_n_d {
/// Helper data structure for data and mask buffers.
/// Note: The `mask` operand is set in TransferOpConversion.
/// Rewrite one unpacked dimension of a TransferReadOp: build the store indices
/// into the temporary buffer (buffer indices + the loop induction variable).
/// The new trailing `loopState` parameter is unused for reads — only
/// tensor-source writes carry loop state. (Body continues in an elided hunk.)
static TransferReadOp rewriteOp(OpBuilder &b,
VectorTransferToSCFOptions options,
- TransferReadOp xferOp, Value buffer,
- Value iv) {
+ TransferReadOp xferOp, Value buffer, Value iv,
+ ValueRange /*loopState*/) {
SmallVector<Value, 8> storeIndices;
getBufferIndices(xferOp, storeIndices);
storeIndices.push_back(iv);
/// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
/// padding value to the temporary buffer.
/// Returns an empty Value: reads carry no scf.for loop state, so there is
/// nothing to forward to the yield.
- static void handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
- Value buffer, Value iv) {
+ static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
+ Value buffer, Value iv,
+ ValueRange /*loopState*/) {
SmallVector<Value, 8> storeIndices;
getBufferIndices(xferOp, storeIndices);
storeIndices.push_back(iv);
// NOTE(review): `bufferType` and `loc` are defined in lines elided from this
// hunk — presumably derived from `buffer` and `xferOp`; verify in full file.
auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
+
+ return Value();
}
/// Cleanup after rewriting the op.
/// Erases the buffer store and the original read. The new `forOp` parameter
/// is unused for reads: the loop produces no result to forward.
- static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp) {
+ static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
+ scf::ForOp /*forOp*/) {
rewriter.eraseOp(getStoreOp(xferOp));
rewriter.eraseOp(xferOp);
}
+
+ /// Return the initial loop state for the generated scf.for loop.
+ /// Reads never operate on tensor state, so the state is always empty.
+ static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};
/// Codegen strategy for vector TransferWriteOp.
/// Rewrite one unpacked dimension of a TransferWriteOp: load the unpacked
/// vector from the temporary buffer and write it into the source. When loop
/// state is present (tensor source), the state value — the tensor produced by
/// the previous iteration — replaces `xferOp.source()` and the new write op is
/// typed to return the updated tensor.
static TransferWriteOp rewriteOp(OpBuilder &b,
VectorTransferToSCFOptions options,
TransferWriteOp xferOp, Value buffer,
- Value iv) {
+ Value iv, ValueRange loopState) {
SmallVector<Value, 8> loadIndices;
getBufferIndices(xferOp, loadIndices);
loadIndices.push_back(iv);
Location loc = xferOp.getLoc();
auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
+ auto source = loopState.empty() ? xferOp.source() : loopState[0];
+ Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
// NOTE(review): `xferIndices` and the `return newXferOp;` are in lines elided
// from this hunk; verify in the full file.
auto newXferOp = b.create<vector::TransferWriteOp>(
- loc, Type(), vec, xferOp.source(), xferIndices,
+ loc, type, vec, source, xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
inBoundsAttr);
}
/// Handle out-of-bounds accesses on the to-be-unpacked dimension.
/// Writes are simply skipped when out of bounds; on a tensor source the
/// incoming loop state (the unmodified tensor) is forwarded so the scf.yield
/// still has a value to carry.
- static void handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
- Value buffer, Value iv) {}
+ static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
+ Value buffer, Value iv,
+ ValueRange loopState) {
+ return isTensorOp(xferOp) ? loopState[0] : Value();
+ }
/// Cleanup after rewriting the op.
/// On a tensor source, the generated scf.for yields the final tensor value,
/// so the original write is replaced by the loop's single result; on a memref
/// source there is no result and the original op is simply erased.
- static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) {
- rewriter.eraseOp(xferOp);
+ static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
+ scf::ForOp forOp) {
+ if (isTensorOp(xferOp)) {
+ assert(forOp->getNumResults() == 1 && "Expected one for loop result");
+ rewriter.replaceOp(xferOp, forOp->getResult(0));
+ } else {
+ rewriter.eraseOp(xferOp);
+ }
+ }
+
+ /// Return the initial loop state for the generated scf.for loop.
+ /// For a tensor source, the loop threads the source tensor through its
+ /// iterations; for a memref source there is no state.
+ static Value initialLoopState(TransferWriteOp xferOp) {
+ return isTensorOp(xferOp) ? xferOp.source() : Value();
+ }
};
return failure();
if (xferOp.getVectorType().getRank() <= options.targetRank)
return failure();
// NOTE(review): fragment of a precondition-check helper whose header is not
// visible in this hunk. The change below makes tensor-source transfer ops
// eligible for lowering when the new `lowerTensors` option is enabled,
// instead of unconditionally bailing out on tensors.
- if (xferOp.getShapedType().template isa<RankedTensorType>())
+ if (isTensorOp(xferOp) && !options.lowerTensors)
return failure();
// Transfer ops that modify the element type are not supported atm.
if (xferOp.getVectorType().getElementType() !=
/// corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
/// out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
+///
+/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
+/// source (as opposed to a memref source), then each iteration of the generated
+/// scf.for loop yields the new tensor value. E.g.:
+/// ```
+/// %result = scf.for i = 0 to 5 {
+/// %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
+/// %1 = vector.transfer_write %0, %source[...]
+/// : vector<4x3xf32>, tensor<5x4x3xf32>
+/// scf.yield %1 : tensor<5x4x3xf32>
+/// }
+/// ```
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
// NOTE(review): large parts of matchAndRewrite are elided from this hunk
// (declarations of `locB`, `lb`, `castedDataType`, `castedDataBuffer`,
// `mask`, `rewriter`, ...); the comments below only describe what is visible.
auto ub = locB.create<ConstantIndexOp>(
castedDataType.getDimSize(castedDataType.getRank() - 1));
auto step = locB.create<ConstantIndexOp>(1);
+ // TransferWriteOps that operate on tensors return the modified tensor and
+ // require a loop state.
+ auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
// Generate for loop.
- locB.create<scf::ForOp>(
- lb, ub, step, ValueRange(),
- [&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) {
- generateInBoundsCheck(
+ auto result = locB.create<scf::ForOp>(
+ lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
+ Type stateType = loopState.empty() ? Type() : loopState[0].getType();
+
+ auto result = generateInBoundsCheck(
b, xferOp, iv, unpackedDim(xferOp),
+ stateType ? TypeRange(stateType) : TypeRange(),
/*inBoundsCase=*/
[&](OpBuilder &b, Location loc) {
// Create new transfer op.
OpTy newXfer = Strategy<OpTy>::rewriteOp(
- b, this->options, xferOp, castedDataBuffer, iv);
+ b, this->options, xferOp, castedDataBuffer, iv, loopState);
// If old transfer op has a mask: Set mask on new transfer op.
// Special case: If the mask of the old transfer op is 1D and
rewriter.updateRootInPlace(
newXfer, [&]() { newXfer.maskMutable().assign(mask); });
}
+
+ // Forward the rewritten op's result (the updated tensor) when
+ // loop state is present.
+ return loopState.empty() ? Value() : newXfer->getResult(0);
},
/*outOfBoundsCase=*/
[&](OpBuilder &b, Location /*loc*/) {
- Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp,
- castedDataBuffer, iv);
+ return Strategy<OpTy>::handleOutOfBoundsDim(
+ b, xferOp, castedDataBuffer, iv, loopState);
});
+
+ // Yield the in-bounds/out-of-bounds result iff the loop carries state.
+ maybeYieldValue(b, loc, !loopState.empty(), result);
});
- Strategy<OpTy>::cleanup(rewriter, xferOp);
+ Strategy<OpTy>::cleanup(rewriter, xferOp, result);
return success();
}
};
// NOTE(review): fragment of the pass's constructor (copying options into pass
// flags); the new `lowerTensors` option is plumbed through alongside the
// existing ones.
this->fullUnroll = options.unroll;
this->targetRank = options.targetRank;
this->lowerPermutationMaps = options.lowerPermutationMaps;
+ this->lowerTensors = options.lowerTensors;
}
void runOnFunction() override {
// Rebuild the options struct from the pass flags before running patterns.
options.unroll = fullUnroll;
options.targetRank = targetRank;
options.lowerPermutationMaps = lowerPermutationMaps;
+ options.lowerTensors = lowerTensors;
// Lower permutation maps first.
if (lowerPermutationMaps) {
--- /dev/null
+// RUN: mlir-opt %s -convert-vector-to-scf='lower-tensors=true' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// A transfer_read from a tensor source needs no loop state: each 1-D slice is
+// read and stored into a temporary memref buffer, which is then loaded back
+// as the full 2-D vector.
+// CHECK-LABEL: func @transfer_read_2d(
+// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<4x9xf32>>
+// CHECK: %[[CASTED:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<4x9xf32>> to memref<4xvector<9xf32>>
+// CHECK: scf.for {{.*}} {
+// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %cst {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+// CHECK: memref.store %[[READ]], %[[CASTED]][%{{.*}}] : memref<4xvector<9xf32>>
+// CHECK: }
+// CHECK: %[[LOADED:.*]] = memref.load %[[ALLOC]][] : memref<vector<4x9xf32>>
+// CHECK: return %[[LOADED]] : vector<4x9xf32>
+func @transfer_read_2d(%A : tensor<?x?xf32>, %base1 : index, %base2 : index)
+ -> (vector<4x9xf32>){
+ %p = constant -42.0: f32
+ %f = vector.transfer_read %A[%base1, %base2], %p {in_bounds = [true, true]}
+ : tensor<?x?xf32>, vector<4x9xf32>
+ return %f : vector<4x9xf32>
+}
+
+// -----
+
+// A transfer_write to a tensor source threads the tensor through the loop as
+// an iter_arg: each iteration writes one 1-D slice and yields the updated
+// tensor, and the loop's result replaces the original op.
+// CHECK-LABEL: func @transfer_write_2d(
+// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<2x3xf32>>
+// CHECK: memref.store {{.*}}, %[[ALLOC]][] : memref<vector<2x3xf32>>
+// CHECK: %[[CASTED:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<2x3xf32>> to memref<2xvector<3xf32>>
+// CHECK: %[[RESULT:.*]] = scf.for {{.*}} iter_args(%[[STATE:.*]] = %{{.*}}) -> (tensor<?x?xf32>) {
+// CHECK: %[[LOADED:.*]] = memref.load %[[CASTED]][%{{.*}}] : memref<2xvector<3xf32>>
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[LOADED]], %[[STATE]][{{.*}}] {in_bounds = [true]} : vector<3xf32>, tensor<?x?xf32>
+// CHECK: scf.yield %[[WRITE]] : tensor<?x?xf32>
+// CHECK: }
+// CHECK: return %[[RESULT]] : tensor<?x?xf32>
+func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
+ %base1 : index, %base2 : index) -> (tensor<?x?xf32>) {
+ %t = vector.transfer_write %vec, %A[%base1, %base2] {in_bounds = [true, true]}
+ : vector<2x3xf32>, tensor<?x?xf32>
+ return %t : tensor<?x?xf32>
+}
+