LLVM IR Conversion: support zero-dimensional memrefs
authorAlex Zinenko <zinenko@google.com>
Wed, 27 Mar 2019 16:39:31 +0000 (09:39 -0700)
committerjpienaar <jpienaar@google.com>
Sat, 30 Mar 2019 00:45:26 +0000 (17:45 -0700)
The spec allows zero-dimensional memrefs to exist and treats them essentially
as single-element buffers.  Unlike single-dimensional memrefs of static shape
<1xTy>, zero-dimensional memrefs do not require indices to access the only
element they store.  Add support of zero-dimensional memrefs to the LLVM IR
conversion.  In particular, such memrefs are converted into bare pointers, and
accesses to them are converted to bare loads and stores, without the overhead
of `getelementptr %buffer, 0`.

PiperOrigin-RevId: 240579456

mlir/g3doc/ConversionToLLVMDialect.md
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/test/LLVMIR/convert-memref-ops.mlir

index 8a2b46c..418dfbf 100644 (file)
@@ -54,8 +54,9 @@ pointer types if they are fully statically shaped; or to LLVM IR structure types
 if they contain dynamic sizes. In the latter case, the first element of the
 structure is a pointer to the converted (using these rules) memref element type,
 followed by as many elements as the memref has dynamic sizes. The type of each
-of these size arguments will be the LLVM type that results from converting 
-the MLIR `index` type.
+of these size arguments will be the LLVM type that results from converting the
+MLIR `index` type. Zero-dimensional memrefs are treated as pointers to the
+elemental type.
 
 Examples:
 
@@ -294,8 +295,24 @@ address is emitted as arithmetic instructions in the LLVM IR dialect. Static
 sizes are introduced as constants. Dynamic sizes are extracted from the memref
 descriptor.
 
+Accesses to zero-dimensional memref (that are interpreted as pointers to the
+elemental type) are directly converted into `llvm.load` or `llvm.store` without
+any pointer manipulations.
+
 Examples:
 
+An access to a zero-dimensional memref is converted into a plain load:
+
+```mlir {.mlir}
+// before
+%0 = load %m[] : memref<f32>
+
+// after
+%0 = "llvm.load"(%m) : (!llvm.type<"float*">) -> (!llvm.type<"float">)
+```
+
+An access to a memref with indices:
+
 ```mlir {.mlir}
 %0 = load %m[1,2,3,4] : memref<10x?x13x?xf32>
 ```
index 3f17c67..d5f430f 100644 (file)
@@ -558,7 +558,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
     MemRefType type = allocOp.getType();
 
     // Get actual sizes of the memref as values: static sizes are constant
-    // values and dynamic sizes are passed to 'alloc' as operands.
+    // values and dynamic sizes are passed to 'alloc' as operands.  In case of
+    // zero-dimensional memref, assume a scalar (size 1).
     SmallVector<Value *, 4> sizes;
     auto numOperands = allocOp.getNumOperands();
     sizes.reserve(numOperands);
@@ -566,7 +567,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
     for (int64_t s : type.getShape())
       sizes.push_back(s == -1 ? operands[i++]
                               : createIndexConstant(rewriter, op->getLoc(), s));
-    assert(!sizes.empty() && "zero-dimensional allocation");
+    if (sizes.empty())
+      sizes.push_back(createIndexConstant(rewriter, op->getLoc(), 1));
 
     // Compute the total number of memref elements.
     Value *cumulativeSize = sizes.front();
@@ -882,11 +884,18 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
                                         ArrayRef<Value *>{dataPtr, subscript},
                                         ArrayRef<NamedAttribute>{});
   }
-  // This is a getElementPtr variant, where the value is a direct raw pointer
+  // This is a getElementPtr variant, where the value is a direct raw pointer.
+  // If a shape is empty, we are dealing with a zero-dimensional memref. Return
+  // the pointer unmodified in this case.  Otherwise, linearize subscripts to
+  // obtain the offset with respect to the base pointer.  Use this offset to
+  // compute and return the element pointer.
   Value *getRawElementPtr(Location loc, Type elementTypePtr,
                           ArrayRef<int64_t> shape, Value *rawDataPtr,
                           ArrayRef<Value *> indices,
                           FuncBuilder &rewriter) const {
+    if (shape.empty())
+      return rawDataPtr;
+
     SmallVector<Value *, 4> sizes;
     for (int64_t s : shape) {
       sizes.push_back(this->createIndexConstant(rewriter, loc, s));
index 266e499..b9d518c 100644 (file)
@@ -12,6 +12,24 @@ func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
   return %static : memref<32x18xf32>
 }
 
+// CHECK-LABEL: func @zero_d_alloc() -> !llvm<"float*"> {
+func @zero_d_alloc() -> memref<f32> {
+// CHECK-NEXT:   %0 = "llvm.constant"() {value: 1 : index} : () -> !llvm<"i64">
+// CHECK-NEXT:   %1 = "llvm.constant"() {value: 4 : index} : () -> !llvm<"i64">
+// CHECK-NEXT:   %2 = "llvm.mul"(%0, %1) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
+// CHECK-NEXT:   %3 = "llvm.call"(%2) {callee: @malloc : (!llvm<"i64">) -> !llvm<"i8*">} : (!llvm<"i64">) -> !llvm<"i8*">
+// CHECK-NEXT:   %4 = "llvm.bitcast"(%3) : (!llvm<"i8*">) -> !llvm<"float*">
+  %0 = alloc() : memref<f32>
+  return %0 : memref<f32>
+}
+
+// CHECK-LABEL: func @zero_d_dealloc(%arg0: !llvm<"float*">) {
+func @zero_d_dealloc(%arg0: memref<f32>) {
+// CHECK-NEXT:   %0 = "llvm.bitcast"(%arg0) : (!llvm<"float*">) -> !llvm<"i8*">
+// CHECK-NEXT:   "llvm.call"(%0) {callee: @free : (!llvm<"i8*">) -> ()} : (!llvm<"i8*">) -> ()
+  dealloc %arg0 : memref<f32>
+  return
+}
 
 // CHECK-LABEL: func @mixed_alloc(%arg0: !llvm<"i64">, %arg1: !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }"> {
 func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
@@ -85,6 +103,13 @@ func @static_dealloc(%static: memref<10x8xf32>) {
   return
 }
 
+// CHECK-LABEL: func @zero_d_load(%arg0: !llvm<"float*">) -> !llvm<"float"> {
+func @zero_d_load(%arg0: memref<f32>) -> f32 {
+// CHECK-NEXT:   %0 = "llvm.load"(%arg0) : (!llvm<"float*">) -> !llvm<"float">
+  %0 = load %arg0[] : memref<f32>
+  return %0 : f32
+}
+
 // CHECK-LABEL: func @static_load
 func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
 // CHECK-NEXT: %0 = "llvm.constant"() {value: 10 : index} : () -> !llvm<"i64">
@@ -123,6 +148,13 @@ func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
   return
 }
 
+// CHECK-LABEL: func @zero_d_store(%arg0: !llvm<"float*">, %arg1: !llvm<"float">) {
+func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
+// CHECK-NEXT:   "llvm.store"(%arg1, %arg0) : (!llvm<"float">, !llvm<"float*">) -> ()
+  store %arg1, %arg0[] : memref<f32>
+  return
+}
+
 // CHECK-LABEL: func @static_store
 func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) {
 // CHECK-NEXT: %0 = "llvm.constant"() {value: 10 : index} : () -> !llvm<"i64">