    // types. `BufferCastOp::fold` handles the same type case.
    if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
      return failure();
-    // If types are not cast-compatible, bail.
+    // If types are definitely not cast-compatible, bail.
    if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                   bufferCast.getType()))
      return failure();
-    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
-                                        tensorLoad.memref());
+
+    // We already know that the types are potentially cast-compatible.
+    // However, if the affine maps differ, we may need to use a copy when
+    // going from a dynamic to a static offset or stride (the canonicalization
+    // cannot know at this point that the cast is really compatible).
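+    // For example, casting memref<?xf32, offset: ?, strides: [1]> to
+    // memref<?xf32, offset: 3, strides: [1]> is only valid when the dynamic
+    // offset actually equals 3 at runtime, which we cannot prove here, so we
+    // allocate a buffer of the target type and copy into it instead.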
+    auto isGuaranteedCastCompatible = [](MemRefType source,
+                                         MemRefType target) {
+      int64_t sourceOffset, targetOffset;
+      SmallVector<int64_t, 4> sourceStrides, targetStrides;
+      if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
+          failed(getStridesAndOffset(target, targetStrides, targetOffset)))
+        return false;
+      auto dynamicToStatic = [](int64_t a, int64_t b) {
+        return a == MemRefType::getDynamicStrideOrOffset() &&
+               b != MemRefType::getDynamicStrideOrOffset();
+      };
+      if (dynamicToStatic(sourceOffset, targetOffset))
+        return false;
+      for (auto it : zip(sourceStrides, targetStrides))
+        if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
+          return false;
+      return true;
+    };
+
+    auto tensorLoadType = tensorLoad.memref().getType().dyn_cast<MemRefType>();
+    auto bufferCastType = bufferCast.getType().dyn_cast<MemRefType>();
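+    // dyn_cast<MemRefType> yields null for unranked memrefs, so unranked
+    // source or target types fall through to the plain memref.cast below.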
+    if (tensorLoadType && bufferCastType &&
+        !isGuaranteedCastCompatible(tensorLoadType, bufferCastType)) {
+      MemRefType resultType = bufferCastType;
+      auto loc = bufferCast.getLoc();
+      SmallVector<Value, 4> dynamicOperands;
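+      // Collect the runtime sizes of all dynamic result dimensions; they are
+      // needed as operands of the AllocOp below.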
+      for (int i = 0; i < resultType.getRank(); ++i) {
+        if (resultType.getShape()[i] != ShapedType::kDynamicSize)
+          continue;
+        auto index = rewriter.createOrFold<ConstantIndexOp>(loc, i);
+        Value size = rewriter.create<tensor::DimOp>(loc, tensorLoad, index);
+        dynamicOperands.push_back(size);
+      }
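+      // Allocate a buffer with the target type and copy the source memref
+      // into it; the copy makes the layout change explicit instead of
+      // asserting it through a cast.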
+      auto copy =
+          rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
+      rewriter.create<CopyOp>(loc, tensorLoad.memref(), copy);
+      rewriter.replaceOp(bufferCast, {copy});
+    } else {
+      rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
+                                          tensorLoad.memref());
+    }
    return success();
  }
};
// CHECK-DAG: #[[$OFF_3:[a-z0-9]+]] = affine_map<(d0) -> (d0 + 3)>
// CHECK-DAG: #[[$OFF_UNK:[a-z0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// Test case: If the memrefs are cast-compatible, canonicalize.
+// Test case: If the memrefs are definitely cast-compatible, canonicalize to a
+// memref.cast.
// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load(
// CHECK-SAME: %[[M:.*]]: memref<?xf32, #[[$OFF_3]]>)
-// CHEKC-SAME: -> memref<?xf32, #[[$OFF_UNK]]> {
+// CHECK-SAME: -> memref<?xf32, #[[$OFF_UNK]]> {
// CHECK-NOT: memref.tensor_load
// CHECK-NOT: memref.buffer_cast
// CHECK: %[[R:.*]] = memref.cast %[[M]]
// CHECK-SAME: memref<?xf32, #[[$OFF_3]]> to memref<?xf32, #[[$OFF_UNK]]>
// CHECK: return %[[R]]
-func @canonicalize_buffer_cast_of_tensor_load(%arg0: memref<?xf32, offset: 3, strides: [1]>)
+func @canonicalize_buffer_cast_of_tensor_load(
+ %arg0: memref<?xf32, offset: 3, strides: [1]>)
-> memref<?xf32, offset: ?, strides: [1]>
{
%0 = memref.tensor_load %arg0 : memref<?xf32, offset: 3, strides: [1]>
%1 = memref.buffer_cast %0 : memref<?xf32, offset: ?, strides: [1]>
return %1 : memref<?xf32, offset: ?, strides: [1]>
}

// -----

+// CHECK-DAG: #[[$OFF_UNK:[a-z0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[$OFF_3:[a-z0-9]+]] = affine_map<(d0) -> (d0 + 3)>
+
+// Test case: If the memrefs are only potentially cast-compatible, canonicalize
+// to an alloc + copy.
+// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load_to_copy(
+// CHECK-SAME: %[[M:.*]]: memref<?xf32, #[[$OFF_UNK]]>)
+// CHECK-SAME: -> memref<?xf32, #[[$OFF_3]]> {
+// CHECK-NOT: memref.tensor_load
+// CHECK-NOT: memref.buffer_cast
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[DIM:.*]] = memref.dim %[[M]], %[[C0]] : memref<?xf32, #[[$OFF_UNK]]>
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32, #[[$OFF_3]]>
+// CHECK: memref.copy %[[M]], %[[ALLOC]]
+// CHECK-SAME: memref<?xf32, #[[$OFF_UNK]]> to memref<?xf32, #[[$OFF_3]]>
+// CHECK: return %[[ALLOC]]
+func @canonicalize_buffer_cast_of_tensor_load_to_copy(
+ %arg0: memref<?xf32, offset: ?, strides: [1]>)
+ -> memref<?xf32, offset: 3, strides: [1]>
+{
+ %0 = memref.tensor_load %arg0 : memref<?xf32, offset: ?, strides: [1]>
+ %1 = memref.buffer_cast %0 : memref<?xf32, offset: 3, strides: [1]>
+ return %1 : memref<?xf32, offset: 3, strides: [1]>
+}
+
+// -----
+
// CHECK-LABEL: func @subview_of_memcast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
// CHECK: %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>