return rewriter.notifyMatchFailure(
        op, "UnrankedMemRefType is not supported.");
  }
+    // Build the ranked result type with an identity (default) layout: that is
+    // the only layout memref.alloc can produce directly.
+    MemRefType memrefType = type.cast<MemRefType>();
+    MemRefLayoutAttrInterface layout;
+    auto allocType =
+        MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
+                        layout, memrefType.getMemorySpace());
+    // Since this implementation always allocates, certain result types of the
+    // clone op cannot be lowered: the identity-layout allocation must be
+    // castable back to the requested result type.
+    if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
+      return failure();
    // Transform a clone operation into alloc + copy operation and pay
    // attention to the shape dimensions.
-    MemRefType memrefType = type.cast<MemRefType>();
    Location loc = op->getLoc();
    SmallVector<Value, 4> dynamicOperands;
    for (int i = 0; i < memrefType.getRank(); ++i) {
+      // memref.alloc takes exactly one operand per *dynamic* dimension;
+      // static dimensions are encoded in the type itself.
+      if (!memrefType.isDynamicDim(i))
+        continue;
+      Value size = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
      Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.input(), size);
      dynamicOperands.push_back(dim);
    }
-    Value alloc = rewriter.replaceOpWithNewOp<memref::AllocOp>(op, memrefType,
-                                                               dynamicOperands);
+
+    // Allocate a memref with identity layout.
+    Value alloc = rewriter.create<memref::AllocOp>(op->getLoc(), allocType,
+                                                   dynamicOperands);
+    // Cast the allocation to the specified type if needed.
+    if (memrefType != allocType)
+      alloc = rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
+    // NOTE: the conversion rewriter defers erasure of `op`, so reading
+    // op.input() for the copy below is still valid after replaceOp.
+    rewriter.replaceOp(op, alloc);
    rewriter.create<memref::CopyOp>(loc, op.input(), alloc);
    return success();
  }
// Static shape: the lowering emits a plain memref.alloc (no dynamic dim
// operands) followed by a copy; the existing dealloc of the source buffer
// is left untouched.
// CHECK-LABEL: @conversion_static
func @conversion_static(%arg0 : memref<2xf32>) -> memref<2xf32> {
- %0 = bufferization.clone %arg0 : memref<2xf32> to memref<2xf32>
- memref.dealloc %arg0 : memref<2xf32>
- return %0 : memref<2xf32>
+ %0 = bufferization.clone %arg0 : memref<2xf32> to memref<2xf32>
+ memref.dealloc %arg0 : memref<2xf32>
+ return %0 : memref<2xf32>
}
// CHECK: %[[ALLOC:.*]] = memref.alloc
// Dynamic shape: the alloc needs a size operand for the dynamic dimension,
// obtained via an arith.constant index fed into memref.dim on the source.
// CHECK-LABEL: @conversion_dynamic
func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
- %1 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
- memref.dealloc %arg0 : memref<?xf32>
- return %1 : memref<?xf32>
+ %1 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
+ memref.dealloc %arg0 : memref<?xf32>
+ return %1 : memref<?xf32>
}
// CHECK: %[[CONST:.*]] = arith.constant
// Unranked memrefs are rejected by the pattern (notifyMatchFailure), so the
// clone op remains and legalization fails.
// Note: the expected-error@+1 directive is positional — nothing may be
// inserted between it and the clone op below.
func @conversion_unknown(%arg0 : memref<*xf32>) -> memref<*xf32> {
// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
- %1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
- memref.dealloc %arg0 : memref<*xf32>
- return %1 : memref<*xf32>
+ %1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
+ memref.dealloc %arg0 : memref<*xf32>
+ return %1 : memref<*xf32>
+}
+
+// -----
+
+// Non-identity (strided, symbolic offset/stride) result layout: the lowering
+// allocates with identity layout and memref.cast's the allocation back to the
+// requested layout before the copy.
+// CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// CHECK-LABEL: func @conversion_with_layout_map(
+// CHECK-SAME: %[[ARG:.*]]: memref<?xf32, #[[$MAP]]>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = memref.dim %[[ARG]], %[[C0]]
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32>
+// CHECK: %[[CASTED:.*]] = memref.cast %[[ALLOC]] : memref<?xf32> to memref<?xf32, #[[$MAP]]>
+// CHECK: memref.copy
+// CHECK: memref.dealloc
+// CHECK: return %[[CASTED]]
+func @conversion_with_layout_map(%arg0 : memref<?xf32, #map>) -> memref<?xf32, #map> {
+ %1 = bufferization.clone %arg0 : memref<?xf32, #map> to memref<?xf32, #map>
+ memref.dealloc %arg0 : memref<?xf32, #map>
+ return %1 : memref<?xf32, #map>
+}
+
+// -----
+
+// This bufferization.clone cannot be lowered because a buffer with this layout
+// map cannot be allocated (or casted to): memref.alloc yields an identity
+// (unit-stride) layout, which — presumably failing CastOp::areCastCompatible —
+// cannot be cast to the static stride of 10 required by #map2, so the pattern
+// returns failure() and legalization of the clone op fails.
+
+#map2 = affine_map<(d0)[s0] -> (d0 * 10 + s0)>
+func @conversion_with_invalid_layout_map(%arg0 : memref<?xf32, #map2>)
+ -> memref<?xf32, #map2> {
+// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
+ %1 = bufferization.clone %arg0 : memref<?xf32, #map2> to memref<?xf32, #map2>
+ memref.dealloc %arg0 : memref<?xf32, #map2>
+ return %1 : memref<?xf32, #map2>
}