The spec allows zero-dimensional memrefs to exist and treats them essentially
as single-element buffers. Unlike single-dimensional memrefs of static shape
<1xTy>, zero-dimensional memrefs do not require indices to access the only
element they store. Add support for zero-dimensional memrefs to the LLVM IR
conversion. In particular, such memrefs are converted into bare pointers, and
accesses to them are converted to bare loads and stores, without the overhead
of `getelementptr %buffer, 0`.
PiperOrigin-RevId: 240579456
if they contain dynamic sizes. In the latter case, the first element of the
structure is a pointer to the converted (using these rules) memref element type,
followed by as many elements as the memref has dynamic sizes. The type of each
-of these size arguments will be the LLVM type that results from converting
-the MLIR `index` type.
+of these size arguments will be the LLVM type that results from converting the
+MLIR `index` type. Zero-dimensional memrefs are treated as pointers to the
+elemental type.
Examples:
sizes are introduced as constants. Dynamic sizes are extracted from the memref
descriptor.
+Accesses to zero-dimensional memrefs (which are interpreted as pointers to the
+elemental type) are directly converted into `llvm.load` or `llvm.store` without
+any pointer manipulations.
+
Examples:
+An access to a zero-dimensional memref is converted into a plain load:
+
+```mlir {.mlir}
+// before
+%0 = load %m[] : memref<f32>
+
+// after
+%0 = "llvm.load"(%m) : (!llvm.type<"float*">) -> (!llvm.type<"float">)
+```
+
+An access to a memref with indices:
+
```mlir {.mlir}
%0 = load %m[1,2,3,4] : memref<10x?x13x?xf32>
```
MemRefType type = allocOp.getType();
// Get actual sizes of the memref as values: static sizes are constant
- // values and dynamic sizes are passed to 'alloc' as operands.
+ // values and dynamic sizes are passed to 'alloc' as operands. In case of
+ // zero-dimensional memref, assume a scalar (size 1).
SmallVector<Value *, 4> sizes;
auto numOperands = allocOp.getNumOperands();
sizes.reserve(numOperands);
for (int64_t s : type.getShape())
sizes.push_back(s == -1 ? operands[i++]
: createIndexConstant(rewriter, op->getLoc(), s));
- assert(!sizes.empty() && "zero-dimensional allocation");
+ if (sizes.empty())
+ sizes.push_back(createIndexConstant(rewriter, op->getLoc(), 1));
// Compute the total number of memref elements.
Value *cumulativeSize = sizes.front();
ArrayRef<Value *>{dataPtr, subscript},
ArrayRef<NamedAttribute>{});
}
- // This is a getElementPtr variant, where the value is a direct raw pointer
+ // This is a getElementPtr variant, where the value is a direct raw pointer.
+ // If a shape is empty, we are dealing with a zero-dimensional memref. Return
+ // the pointer unmodified in this case. Otherwise, linearize subscripts to
+ // obtain the offset with respect to the base pointer. Use this offset to
+ // compute and return the element pointer.
Value *getRawElementPtr(Location loc, Type elementTypePtr,
ArrayRef<int64_t> shape, Value *rawDataPtr,
ArrayRef<Value *> indices,
FuncBuilder &rewriter) const {
+ if (shape.empty())
+ return rawDataPtr;
+
SmallVector<Value *, 4> sizes;
for (int64_t s : shape) {
sizes.push_back(this->createIndexConstant(rewriter, loc, s));
return %static : memref<32x18xf32>
}
+// CHECK-LABEL: func @zero_d_alloc() -> !llvm<"float*"> {
+// A zero-dimensional memref holds a single element, so the lowering
+// materializes an element count of 1, multiplies it by the element width
+// (4 bytes here), calls malloc, and bitcasts the resulting i8* to a bare
+// float* — no memref descriptor structure is built.
+func @zero_d_alloc() -> memref<f32> {
+// CHECK-NEXT: %0 = "llvm.constant"() {value: 1 : index} : () -> !llvm<"i64">
+// CHECK-NEXT: %1 = "llvm.constant"() {value: 4 : index} : () -> !llvm<"i64">
+// CHECK-NEXT: %2 = "llvm.mul"(%0, %1) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
+// CHECK-NEXT: %3 = "llvm.call"(%2) {callee: @malloc : (!llvm<"i64">) -> !llvm<"i8*">} : (!llvm<"i64">) -> !llvm<"i8*">
+// CHECK-NEXT: %4 = "llvm.bitcast"(%3) : (!llvm<"i8*">) -> !llvm<"float*">
+ %0 = alloc() : memref<f32>
+ return %0 : memref<f32>
+}
+
+// CHECK-LABEL: func @zero_d_dealloc(%arg0: !llvm<"float*">) {
+// A zero-dimensional memref is lowered to a bare pointer, so dealloc is
+// just a bitcast of that pointer to i8* followed by a call to free — there
+// is no descriptor to extract the buffer pointer from.
+func @zero_d_dealloc(%arg0: memref<f32>) {
+// CHECK-NEXT: %0 = "llvm.bitcast"(%arg0) : (!llvm<"float*">) -> !llvm<"i8*">
+// CHECK-NEXT: "llvm.call"(%0) {callee: @free : (!llvm<"i8*">) -> ()} : (!llvm<"i8*">) -> ()
+ dealloc %arg0 : memref<f32>
+ return
+}
// CHECK-LABEL: func @mixed_alloc(%arg0: !llvm<"i64">, %arg1: !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }"> {
func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
return
}
+// CHECK-LABEL: func @zero_d_load(%arg0: !llvm<"float*">) -> !llvm<"float"> {
+// Loading from a zero-dimensional memref (no indices) dereferences the bare
+// pointer directly with llvm.load — no getelementptr is emitted.
+func @zero_d_load(%arg0: memref<f32>) -> f32 {
+// CHECK-NEXT: %0 = "llvm.load"(%arg0) : (!llvm<"float*">) -> !llvm<"float">
+ %0 = load %arg0[] : memref<f32>
+ return %0 : f32
+}
+
// CHECK-LABEL: func @static_load
func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
// CHECK-NEXT: %0 = "llvm.constant"() {value: 10 : index} : () -> !llvm<"i64">
return
}
+// CHECK-LABEL: func @zero_d_store(%arg0: !llvm<"float*">, %arg1: !llvm<"float">) {
+// Storing to a zero-dimensional memref (no indices) writes through the bare
+// pointer directly with llvm.store — no pointer arithmetic is performed.
+func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
+// CHECK-NEXT: "llvm.store"(%arg1, %arg0) : (!llvm<"float">, !llvm<"float*">) -> ()
+ store %arg1, %arg0[] : memref<f32>
+ return
+}
+
// CHECK-LABEL: func @static_store
func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) {
// CHECK-NEXT: %0 = "llvm.constant"() {value: 10 : index} : () -> !llvm<"i64">