Emit LLVM IR equivalent of sizeof when lowering alloc operations
authorAlex Zinenko <zinenko@google.com>
Fri, 11 Oct 2019 13:22:40 +0000 (06:22 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 11 Oct 2019 13:33:26 +0000 (06:33 -0700)
Originally, the lowering of `alloc` operations has been computing the number of
bytes to allocate when lowering based on the properties of MLIR type. This does
not take into account type legalization that happens when compiling LLVM IR
down to target assembly. This legalization can widen the type, potentially
leading to out-of-bounds accesses to `alloc`ed data due to mismatches between
address computation that takes the widening into account and allocation that
does not. Use the LLVM IR's equivalent of `sizeof` to compute the number of
bytes to be allocated:
  %0 = getelementptr %type* null, %indexType 0
  %1 = ptrtoint %type* %0 to %indexType
adapted from
http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt

PiperOrigin-RevId: 274159900

mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir
mlir/test/Examples/Toy/Ch5/lowering.toy

index 4282079..0c162cb 100644 (file)
@@ -658,21 +658,24 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
           op->getLoc(), getIndexType(),
           ArrayRef<Value *>{cumulativeSize, sizes[i]});
 
-    // Compute the total amount of bytes to allocate.
+    // Compute the size of an individual element. This emits the MLIR equivalent
+    // of the following sizeof(...) implementation in LLVM IR:
+    //   %0 = getelementptr %elementType* null, %indexType 1
+    //   %1 = ptrtoint %elementType* %0 to %indexType
+    // which is a common pattern of getting the size of a type in bytes.
     auto elementType = type.getElementType();
-    assert((elementType.isIntOrFloat() || elementType.isa<VectorType>()) &&
-           "invalid memref element type");
-    uint64_t elementSize = 0;
-    if (auto vectorType = elementType.dyn_cast<VectorType>())
-      elementSize = vectorType.getNumElements() *
-                    llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
-    else
-      elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
+    auto convertedPtrType =
+        lowering.convertType(elementType).cast<LLVM::LLVMType>().getPointerTo();
+    auto nullPtr =
+        rewriter.create<LLVM::NullOp>(op->getLoc(), convertedPtrType);
+    auto one = createIndexConstant(rewriter, op->getLoc(), 1);
+    auto gep = rewriter.create<LLVM::GEPOp>(op->getLoc(), convertedPtrType,
+                                            ArrayRef<Value *>{nullPtr, one});
+    auto elementSize =
+        rewriter.create<LLVM::PtrToIntOp>(op->getLoc(), getIndexType(), gep);
     cumulativeSize = rewriter.create<LLVM::MulOp>(
         op->getLoc(), getIndexType(),
-        ArrayRef<Value *>{
-            cumulativeSize,
-            createIndexConstant(rewriter, op->getLoc(), elementSize)});
+        ArrayRef<Value *>{cumulativeSize, elementSize});
 
     // Insert the `malloc` declaration if it is not already present.
     auto module = op->getParentOfType<ModuleOp>();
index 3b3a011..4db4bf9 100644 (file)
@@ -22,8 +22,11 @@ func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
 // CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, i64 }"> {
 func @zero_d_alloc() -> memref<f32> {
 // CHECK-NEXT:  llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT:  llvm.mlir.constant(4 : index) : !llvm.i64
-// CHECK-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// CHECK-NEXT:  %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
+// CHECK-NEXT:  %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT:  %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT:  %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
+// CHECK-NEXT:  llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
 // CHECK-NEXT:  llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm<"i8*">
 // CHECK-NEXT:  %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
 // CHECK-NEXT:  llvm.mlir.undef : !llvm<"{ float*, i64 }">
@@ -50,8 +53,11 @@ func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
 //  CHECK-NEXT:  %[[c42:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
 //  CHECK-NEXT:  llvm.mul %[[M]], %[[c42]] : !llvm.i64
 //  CHECK-NEXT:  %[[sz:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64
-//  CHECK-NEXT:  llvm.mlir.constant(4 : index) : !llvm.i64
-//  CHECK-NEXT:  %[[sz_bytes:.*]] = llvm.mul %[[sz]], %{{.*}} : !llvm.i64
+//  CHECK-NEXT:  %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
+//  CHECK-NEXT:  %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+//  CHECK-NEXT:  %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+//  CHECK-NEXT:  %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
+//  CHECK-NEXT:  %[[sz_bytes:.*]] = llvm.mul %[[sz]], %[[sizeof]] : !llvm.i64
 //  CHECK-NEXT:  llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm<"i8*">
 //  CHECK-NEXT:  llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
 //  CHECK-NEXT:  llvm.mlir.undef : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
@@ -87,8 +93,11 @@ func @mixed_dealloc(%arg0: memref<?x42x?xf32>) {
 //       CHECK:   %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> {
 func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
 //  CHECK-NEXT:  %[[sz:.*]] = llvm.mul %[[M]], %[[N]] : !llvm.i64
-//  CHECK-NEXT:  llvm.mlir.constant(4 : index) : !llvm.i64
-//  CHECK-NEXT:  %[[sz_bytes:.*]] = llvm.mul %[[sz]], %{{.*}} : !llvm.i64
+//  CHECK-NEXT:  %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
+//  CHECK-NEXT:  %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+//  CHECK-NEXT:  %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+//  CHECK-NEXT:  %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
+//  CHECK-NEXT:  %[[sz_bytes:.*]] = llvm.mul %[[sz]], %[[sizeof]] : !llvm.i64
 //  CHECK-NEXT:  llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm<"i8*">
 //  CHECK-NEXT:  llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
 //  CHECK-NEXT:  llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
@@ -118,13 +127,16 @@ func @dynamic_dealloc(%arg0: memref<?x?xf32>) {
 
 // CHECK-LABEL: func @static_alloc() -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> {
 func @static_alloc() -> memref<32x18xf32> {
-// CHECK-NEXT:  %0 = llvm.mlir.constant(32 : index) : !llvm.i64
-// CHECK-NEXT:  %1 = llvm.mlir.constant(18 : index) : !llvm.i64
-// CHECK-NEXT:  %2 = llvm.mul %0, %1 : !llvm.i64
-// CHECK-NEXT:  %3 = llvm.mlir.constant(4 : index) : !llvm.i64
-// CHECK-NEXT:  %4 = llvm.mul %2, %3 : !llvm.i64
-// CHECK-NEXT:  %5 = llvm.call @malloc(%4) : (!llvm.i64) -> !llvm<"i8*">
-// CHECK-NEXT:  %6 = llvm.bitcast %5 : !llvm<"i8*"> to !llvm<"float*">
+// CHECK-NEXT:  %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// CHECK-NEXT:  %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// CHECK-NEXT:  %[[num_elems:.*]] = llvm.mul %0, %1 : !llvm.i64
+// CHECK-NEXT:  %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
+// CHECK-NEXT:  %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT:  %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+//  CHECK-NEXT:  %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
+// CHECK-NEXT:  %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
+// CHECK-NEXT:  %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*">
+// CHECK-NEXT:  llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">
  %0 = alloc() : memref<32x18xf32>
  return %0 : memref<32x18xf32>
 }
index 3c198a6..6f16437 100644 (file)
@@ -6,7 +6,7 @@ def multiply_transpose(a, b) {
 }
 
 # CHECK: define void @main() {
-# CHECK:  %1 = call i8* @malloc(i64 48)
+# CHECK:  %1 = call i8* @malloc(i64 mul (i64 ptrtoint (double* getelementptr (double, double* null, i64 1) to i64), i64 6))
 def main() {
   var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
   var b<2, 3> = [1, 2, 3, 4, 5, 6];