Homogenize the description of the MemRef conversion to the LLVM dialect
authorAlex Zinenko <zinenko@google.com>
Tue, 17 Dec 2019 19:32:19 +0000 (11:32 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 17 Dec 2019 19:32:50 +0000 (11:32 -0800)
The conversion procedure has been updated to reflect the most recent MemRef
descriptor proposal, but the documentation was only updated for the type
conversion, omitting the address computation section. Make sure the two
sections agree.

PiperOrigin-RevId: 286022684

mlir/g3doc/ConversionToLLVMDialect.md

index 3564d4c..634f107 100644 (file)
@@ -91,9 +91,10 @@ memref<1x? x vector<4xf32>> -> !llvm.type<"{ <4 x float>*, <4 x float>*, i64, [1
 ```
 
 If the rank of the memref is unknown at compile time, the Memref is converted to
-an unranked descriptor that contains: 1. a 64-bit integer representing the
-dynamic rank of the memref, followed by 2. a pointer to a ranked memref
-descriptor with the contents listed above.
+an unranked descriptor that contains:
+
+1.  a 64-bit integer representing the dynamic rank of the memref, followed by
+2.  a pointer to a ranked memref descriptor with the contents listed above.
 
 Dynamic ranked memrefs should be used only to pass arguments to external library
 calls that expect a unified memref type. The called functions can parse any
@@ -354,23 +355,24 @@ before the conversion to the LLVM IR dialect:
 
 Within a converted function, a `memref`-typed value is represented by a memref
 _descriptor_, the type of which is the structure type obtained by converting
-from the memref type. This descriptor holds a pointer to a linear buffer storing
-the data, and dynamic sizes of the memref value. It is created by the allocation
-operation and is updated by the conversion operations that may change static
-dimensions into dynamic and vice versa.
+from the memref type. This descriptor holds all the necessary information to
+produce an address of a specific element. In particular, it holds dynamic values
+for static sizes, and they are expected to match at all times.
+
+It is created by the allocation operation and is updated by the conversion
+operations that may change static dimensions into dynamic and vice versa.
 
-Note: LLVM IR conversion does not support `memref`s in non-default memory spaces
-or `memref`s with non-identity layouts.
+**Note**: LLVM IR conversion does not support `memref`s with layouts that are
+not amenable to the strided form.
 
 ### Index Linearization
 
 Accesses to a memref element are transformed into an access to an element of the
 buffer pointed to by the descriptor. The position of the element in the buffer
 is calculated by linearizing memref indices in row-major order (lexically first
-index is the slowest varying, similar to C). The computation of the linear
-address is emitted as arithmetic operation in the LLVM IR dialect. Static sizes
-are introduced as constants. Dynamic sizes are extracted from the memref
-descriptor.
+index is the slowest varying, similar to C, but accounting for strides). The
+computation of the linear address is emitted as arithmetic operation in the LLVM
+IR dialect. Strides are extracted from the memref descriptor.
 
 Accesses to zero-dimensional memref (that are interpreted as pointers to the
 elemental type) are directly converted into `llvm.load` or `llvm.store` without
@@ -385,7 +387,7 @@ An access to a zero-dimensional memref is converted into a plain load:
 %0 = load %m[] : memref<f32>
 
 // after
-%0 = llvm.load %m : !llvm.type<"float*">
+%0 = llvm.load %m : !llvm<"float*">
 ```
 
 An access to a memref with indices:
@@ -397,40 +399,45 @@ An access to a memref with indices:
 is transformed into the equivalent of the following code:
 
 ```mlir
-// obtain the buffer pointer
-%b = llvm.extractvalue %m[0] : !llvm.type<"{float*, i64, i64}">
-
-// obtain the components for the index
-%sub1 = llvm.mlir.constant(1) : !llvm.type<"i64">  // first subscript
-%sz2 = llvm.extractvalue %m[1]
-    : !llvm.type<"{float*, i64, i64}"> // second size (dynamic, second descriptor element)
-%sub2 = llvm.mlir.constant(2) : !llvm.type<"i64">  // second subscript
-%sz3 = llvm.mlir.constant(13) : !llvm.type<"i64">  // third size (static)
-%sub3 = llvm.mlir.constant(3) : !llvm.type<"i64">  // third subscript
-%sz4 = llvm.extractvalue %m[1]
-    : !llvm.type<"{float*, i64, i64}"> // fourth size (dynamic, third descriptor element)
-%sub4 = llvm.mlir.constant(4) : !llvm.type<"i64">  // fourth subscript
-
-// compute the linearized index
-// %sub4 + %sub3 * %sz4 + %sub2 * (%sz3 * %sz4) + %sub1 * (%sz2 * %sz3 * %sz4) =
-// = ((%sub1 * %sz2 + %sub2) * %sz3 + %sub3) * %sz4 + %sub4
-%idx0 = llvm.mul %sub1, %sz2 : !llvm.type<"i64">
-%idx1 = llvm.add %idx0, %sub : !llvm.type<"i64">
-%idx2 = llvm.mul %idx1, %sz3 : !llvm.type<"i64">
-%idx3 = llvm.add %idx2, %sub3 : !llvm.type<"i64">
-%idx4 = llvm.mul %idx3, %sz4 : !llvm.type<"i64">
-%idx5 = llvm.add %idx4, %sub4 : !llvm.type<"i64">
-
-// obtain the element address
-%a = llvm.getelementptr %b[%idx5] : (!llvm.type<"float*">, !llvm.type<"i64">) -> !llvm.type<"float*">
-
-// perform the actual load
-%0 = llvm.load %a : !llvm.type<"float*">
+// Compute the linearized index from strides. Each block below extracts one
+// stride from the descriptor, multipllies it with the index and accumulates
+// the total offset.
+%stride1 = llvm.extractvalue[4, 0] : !llvm<"{float*, float*, i64, i64[4], i64[4]}">
+%idx1 = llvm.mlir.constant(1 : index) !llvm.i64
+%addr1 = muli %stride1, %idx1 : !llvm.i64
+
+%stride2 = llvm.extractvalue[4, 1] : !llvm<"{float*, float*, i64, i64[4], i64[4]}">
+%idx2 = llvm.mlir.constant(2 : index) !llvm.i64
+%addr2 = muli %stride2, %idx2 : !llvm.i64
+%addr3 = addi %addr1, %addr2 : !llvm.i64
+
+%stride3 = llvm.extractvalue[4, 2] : !llvm<"{float*, float*, i64, i64[4], i64[4]}">
+%idx3 = llvm.mlir.constant(3 : index) !llvm.i64
+%addr4 = muli %stride3, %idx3 : !llvm.i64
+%addr5 = addi %addr3, %addr4 : !llvm.i64
+
+%stride4 = llvm.extractvalue[4, 3] : !llvm<"{float*, float*, i64, i64[4], i64[4]}">
+%idx4 = llvm.mlir.constant(4 : index) !llvm.i64
+%addr6 = muli %stride4, %idx4 : !llvm.i64
+%addr7 = addi %addr5, %addr6 : !llvm.i64
+
+// Add the linear offset to the address.
+%offset = llvm.extractvalue[2] : !llvm<"{float*, float*, i64, i64[4], i64[4]}">
+%addr8 = addi %addr7, %offset : !llvm.i64
+
+// Obtain the aligned pointer.
+%aligned = llvm.extractvalue[1] : !llvm<"{float*, float*, i64, i64[4], i64[4]}">
+
+// Get the address of the data pointer.
+%ptr = llvm.getelementptr %aligned[%addr8]
+    : !llvm<"{float*, float*, i64, i64[4], i64[4]}"> -> !llvm<"float*">
+
+// Perform the actual load.
+%0 = llvm.load %ptr : !llvm<"float*">
 ```
 
-In practice, the subscript and size extraction will be interleaved with the
-linear index computation. For stores, the address computation code is identical
-and only the actual store operation is different.
+For stores, the address computation code is identical and only the actual store
+operation is different.
 
 Note: the conversion does not perform any sort of common subexpression
 elimination when emitting memref accesses.