OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(castOp);
- Type sourceType = lookup(bvm, castOp.source()).getType();
+ // If castOp is not inPlace, allocate a new buffer.
+ auto inPlace = getInPlace(castOp->getResult(0));
+ Value newBuffer;
+ if (inPlace != InPlaceSpec::True) {
+ Location loc = castOp.getLoc();
+ // Alloc a copy of `castOp.source()`; it will become the result buffer.
+ newBuffer = createNewAllocDeallocPairForShapedValue(b, loc, castOp.source(),
+ aliasInfo);
+ if (!isInitTensorOp(castOp.source())) {
+ // Set insertion point now that potential alloc/dealloc are introduced.
+ b.setInsertionPoint(castOp);
+ b.create<CopyOp>(loc, lookup(bvm, castOp.source()), newBuffer);
+ }
+ } else {
+ // InPlace write will result in memref.tensor_load(x) which must
+ // canonicalize away with one of its uses.
+ newBuffer = lookup(bvm, castOp.source());
+ assert(newBuffer && "missing buffer");
+ }
+
+ Type sourceType = newBuffer.getType();
auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
assert(rankedMemRefType || unrankedMemRefType);
: ArrayRef<AffineMap>{};
Type memRefType = getContiguousOrUnrankedMemRefType(
castOp.getResult().getType(), affineMaps, memorySpace);
- Value res = b.create<memref::CastOp>(castOp.getLoc(), memRefType,
- lookup(bvm, castOp.source()));
+ Value res = b.create<memref::CastOp>(castOp.getLoc(), memRefType, newBuffer);
aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
map(bvm, castOp.getResult(), res);
return success();
}
return %0 : tensor<128x192xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @tensor_cast_not_in_place(
+// CHECK-SAME: %[[A:.*]]: memref<?xf32{{.*}}>, %[[B:.*]]: memref<?xf32{{.*}}>
+// CHECK: %[[alloc:.*]] = memref.alloc
+// CHECK: linalg.copy(%[[A]], %[[alloc]])
+// CHECK: %[[cast:.*]] = memref.cast %[[alloc]]
+func @tensor_cast_not_in_place(
+    %A : tensor<?xf32> {linalg.inplaceable = true},
+    %B : tensor<?xf32>, %idx: index)
+  -> (tensor<?xf32>)
+{
+  // NOTE(review): %r0 aliases %A, and %A is then overwritten by the
+  // insert_slice below — presumably this is what forces the cast result
+  // out-of-place, so bufferization must alloc a new buffer and copy %A into
+  // it before casting (matched by the alloc/copy/cast CHECK lines above).
+  %r0 = tensor.cast %A : tensor<?xf32> to tensor<4xf32>
+  %r1 = tensor.insert_slice %r0 into %A[%idx][4][1] : tensor<4xf32> into tensor<?xf32>
+  return %r1 : tensor<?xf32>
+}
+