[mlir] Require struct indices in LLVM::GEPOp to be constant
authorAlex Zinenko <zinenko@google.com>
Thu, 6 Jan 2022 22:30:15 +0000 (23:30 +0100)
committerAlex Zinenko <zinenko@google.com>
Fri, 7 Jan 2022 08:56:05 +0000 (09:56 +0100)
Recent commits added a possibility for indices in LLVM dialect GEP operations
to be supplied directly as constant attributes to ensure they remain such until
translation to LLVM IR happens. Make this required for indexing into LLVM
struct types to match LLVM IR requirements, otherwise the translation would
assert on constructing such IR.

For better compatibility with MLIR-style operation construction interface,
allow GEP operations to be constructed programmatically using Values pointing
to known constant operations as struct indices.

Depends On D116758

Reviewed By: wsmoses

Differential Revision: https://reviews.llvm.org/D116759

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Target/LLVMIR/llvmir.mlir

index 3d6f7f3..328ff9d 100644 (file)
@@ -350,6 +350,9 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> {
     constexpr static int kDynamicIndex = std::numeric_limits<int32_t>::min();
   }];
   let hasFolder = 1;
+  let verifier = [{
+    return ::verify(*this);
+  }];
 }
 
 def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
index 56b14a7..0f6c91f 100644 (file)
@@ -360,6 +360,58 @@ SwitchOp::getMutableSuccessorOperands(unsigned index) {
 // Code for LLVM::GEPOp.
 //===----------------------------------------------------------------------===//
 
+/// Populates `indices` with positions of GEP indices that would correspond to
+/// LLVMStructTypes potentially nested in the given type. The type currently
+/// visited gets `currentIndex` and LLVM container types are visited
+/// recursively. The recursion is bounded and takes care of recursive types by
+/// means of the `visited` set.
+static void recordStructIndices(Type type, unsigned currentIndex,
+                                SmallVectorImpl<unsigned> &indices,
+                                SmallVectorImpl<unsigned> *structSizes,
+                                SmallPtrSet<Type, 4> &visited) {
+  if (visited.contains(type))
+    return;
+
+  visited.insert(type);
+
+  llvm::TypeSwitch<Type>(type)
+      .Case<LLVMStructType>([&](LLVMStructType structType) {
+        indices.push_back(currentIndex);
+        if (structSizes)
+          structSizes->push_back(structType.getBody().size());
+        for (Type elementType : structType.getBody())
+          recordStructIndices(elementType, currentIndex + 1, indices,
+                              structSizes, visited);
+      })
+      .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
+            LLVMArrayType>([&](auto containerType) {
+        recordStructIndices(containerType.getElementType(), currentIndex + 1,
+                            indices, structSizes, visited);
+      });
+}
+
+/// Populates `indices` with positions of GEP indices that correspond to
+/// LLVMStructTypes potentially nested in the given `baseGEPType`, which must
+/// be either an LLVMPointer type or a vector thereof. If `structSizes` is
+/// provided, it is populated with sizes of the indexed structs for bounds
+/// verification purposes.
+static void
+findKnownStructIndices(Type baseGEPType, SmallVectorImpl<unsigned> &indices,
+                       SmallVectorImpl<unsigned> *structSizes = nullptr) {
+  Type type = baseGEPType;
+  if (auto vectorType = type.dyn_cast<VectorType>())
+    type = vectorType.getElementType();
+  if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>())
+    type = scalableVectorType.getElementType();
+  if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>())
+    type = fixedVectorType.getElementType();
+
+  Type pointeeType = type.cast<LLVMPointerType>().getElementType();
+  SmallPtrSet<Type, 4> visited;
+  recordStructIndices(pointeeType, /*currentIndex=*/1, indices, structSizes,
+                      visited);
+}
+
 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                   Value basePtr, ValueRange operands,
                   ArrayRef<NamedAttribute> attributes) {
@@ -372,11 +424,58 @@ void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                   Value basePtr, ValueRange indices,
                   ArrayRef<int32_t> structIndices,
                   ArrayRef<NamedAttribute> attributes) {
+  SmallVector<Value> remainingIndices;
+  SmallVector<int32_t> updatedStructIndices(structIndices.begin(),
+                                            structIndices.end());
+  SmallVector<unsigned> structRelatedPositions;
+  findKnownStructIndices(basePtr.getType(), structRelatedPositions);
+
+  SmallVector<unsigned> operandsToErase;
+  for (unsigned pos : structRelatedPositions) {
+    // GEP may not be indexing as deep as some structs are located.
+    if (pos >= structIndices.size())
+      continue;
+
+    // If the index is already static, it's fine.
+    if (structIndices[pos] != kDynamicIndex)
+      continue;
+
+    // Find the corresponding operand.
+    unsigned operandPos =
+        std::count(structIndices.begin(), std::next(structIndices.begin(), pos),
+                   kDynamicIndex);
+
+    // Extract the constant value from the operand and put it into the attribute
+    // instead.
+    APInt staticIndexValue;
+    bool matched =
+        matchPattern(indices[operandPos], m_ConstantInt(&staticIndexValue));
+    (void)matched;
+    assert(matched && "index into a struct must be a constant");
+    assert(staticIndexValue.sge(APInt::getSignedMinValue(/*numBits=*/32)) &&
+           "struct index underflows 32-bit integer");
+    assert(staticIndexValue.sle(APInt::getSignedMaxValue(/*numBits=*/32)) &&
+           "struct index overflows 32-bit integer");
+    auto staticIndex = static_cast<int32_t>(staticIndexValue.getSExtValue());
+    updatedStructIndices[pos] = staticIndex;
+    operandsToErase.push_back(operandPos);
+  }
+
+  for (unsigned i = 0, e = indices.size(); i < e; ++i) {
+    if (llvm::find(operandsToErase, i) == operandsToErase.end())
+      remainingIndices.push_back(indices[i]);
+  }
+
+  assert(remainingIndices.size() == static_cast<size_t>(llvm::count(
+                                        updatedStructIndices, kDynamicIndex)) &&
+         "exected as many index operands as dynamic index attr elements");
+
   result.addTypes(resultType);
   result.addAttributes(attributes);
-  result.addAttribute("structIndices", builder.getI32TensorAttr(structIndices));
+  result.addAttribute("structIndices",
+                      builder.getI32TensorAttr(updatedStructIndices));
   result.addOperands(basePtr);
-  result.addOperands(indices);
+  result.addOperands(remainingIndices);
 }
 
 static ParseResult
@@ -417,6 +516,27 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
                         });
 }
 
+LogicalResult verify(LLVM::GEPOp gepOp) {
+  SmallVector<unsigned> indices;
+  SmallVector<unsigned> structSizes;
+  findKnownStructIndices(gepOp.getBase().getType(), indices, &structSizes);
+  for (unsigned i = 0, e = indices.size(); i < e; ++i) {
+    unsigned index = indices[i];
+    // GEP may not be indexing as deep as some structs nested in the type.
+    if (index >= gepOp.getStructIndices().getNumElements())
+      continue;
+
+    int32_t staticIndex = gepOp.getStructIndices().getValues<int32_t>()[index];
+    if (staticIndex == LLVM::GEPOp::kDynamicIndex)
+      return gepOp.emitOpError() << "expected index " << index
+                                 << " indexing a struct to be constant";
+    if (staticIndex < 0 || static_cast<unsigned>(staticIndex) >= structSizes[i])
+      return gepOp.emitOpError()
+             << "index " << index << " indexing a struct is out of bounds";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Builder, printer and parser for for LLVM::LoadOp.
 //===----------------------------------------------------------------------===//
index 321e619..be16764 100644 (file)
@@ -501,8 +501,7 @@ func @memref_reshape(%input : memref<2x3xf32>, %shape : memref<?xindex>) {
 // CHECK: [[STRUCT_PTR:%.*]] = llvm.bitcast [[UNDERLYING_DESC]]
 // CHECK-SAME: !llvm.ptr<i8> to !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, i64)>>
 // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: [[C3_I32:%.*]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK: [[SIZES_PTR:%.*]] = llvm.getelementptr [[STRUCT_PTR]]{{\[}}[[C0]], [[C3_I32]]]
+// CHECK: [[SIZES_PTR:%.*]] = llvm.getelementptr [[STRUCT_PTR]]{{\[}}[[C0]], 3]
 // CHECK: [[STRIDES_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[RANK]]]
 // CHECK: [[SHAPE_IN_PTR:%.*]] = llvm.extractvalue [[SHAPE]][1] : [[SHAPE_TY]]
 // CHECK: [[C1_:%.*]] = llvm.mlir.constant(1 : index) : i64
index 70ba47d..d790b2b 100644 (file)
@@ -547,12 +547,11 @@ func @dim_of_unranked(%unranked: memref<*xi32>) -> index {
 // CHECK: %[[ZERO_D_DESC:.*]] = llvm.bitcast %[[RANKED_DESC]]
 // CHECK-SAME:   : !llvm.ptr<i8> to !llvm.ptr<struct<(ptr<i32>, ptr<i32>, i64)>>
 
-// CHECK: %[[C2_i32:.*]] = llvm.mlir.constant(2 : i32) : i32
 // CHECK: %[[C0_:.*]] = llvm.mlir.constant(0 : index) : i64
 
 // CHECK: %[[OFFSET_PTR:.*]] = llvm.getelementptr %[[ZERO_D_DESC]]{{\[}}
-// CHECK-SAME:   %[[C0_]], %[[C2_i32]]] : (!llvm.ptr<struct<(ptr<i32>, ptr<i32>,
-// CHECK-SAME:   i64)>>, i64, i32) -> !llvm.ptr<i64>
+// CHECK-SAME:   %[[C0_]], 2] : (!llvm.ptr<struct<(ptr<i32>, ptr<i32>,
+// CHECK-SAME:   i64)>>, i64) -> !llvm.ptr<i64>
 
 // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: %[[INDEX_INC:.*]] = llvm.add %[[C1]], %{{.*}} : i64
index effc9be..ea68dc9 100644 (file)
@@ -10,7 +10,7 @@ spv.func @access_chain() "None" {
   %0 = spv.Constant 1: i32
   %1 = spv.Variable : !spv.ptr<!spv.struct<(f32, !spv.array<4xf32>)>, Function>
   // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %[[ONE]], %[[ONE]]] : (!llvm.ptr<struct<packed (f32, array<4 x f32>)>>, i32, i32, i32) -> !llvm.ptr<f32>
+  // CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], 1, %[[ONE]]] : (!llvm.ptr<struct<packed (f32, array<4 x f32>)>>, i32, i32) -> !llvm.ptr<f32>
   %2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<(f32, !spv.array<4xf32>)>, Function>, i32, i32
   spv.Return
 }
index 41919ce..82561a7 100644 (file)
@@ -1234,3 +1234,19 @@ func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
   nvvm.cp.async.shared.global %arg0, %arg1, 32
   return
 }
+
+// -----
+
+func @gep_struct_variable(%arg0: !llvm.ptr<struct<(i32)>>, %arg1: i32, %arg2: i32) {
+  // expected-error @below {{op expected index 1 indexing a struct to be constant}}
+  llvm.getelementptr %arg0[%arg1, %arg1] : (!llvm.ptr<struct<(i32)>>, i32, i32) -> !llvm.ptr<i32>
+  return
+}
+
+// -----
+
+func @gep_out_of_bounds(%ptr: !llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, %idx: i64) {
+  // expected-error @below {{index 2 indexing a struct is out of bounds}}
+  llvm.getelementptr %ptr[%idx, 1, 3] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
+  return
+}
index 04a65f8..c4a5434 100644 (file)
@@ -1444,7 +1444,7 @@ llvm.mlir.global linkonce @take_self_address() : !llvm.struct<(i32, !llvm.ptr<i3
   %z32 = llvm.mlir.constant(0 : i32) : i32
   %0 = llvm.mlir.undef : !llvm.struct<(i32, !llvm.ptr<i32>)>
   %1 = llvm.mlir.addressof @take_self_address : !llvm.ptr<!llvm.struct<(i32, !llvm.ptr<i32>)>>
-  %2 = llvm.getelementptr %1[%z32, %z32] : (!llvm.ptr<!llvm.struct<(i32, !llvm.ptr<i32>)>>, i32, i32) -> !llvm.ptr<i32>
+  %2 = llvm.getelementptr %1[%z32, 0] : (!llvm.ptr<!llvm.struct<(i32, !llvm.ptr<i32>)>>, i32) -> !llvm.ptr<i32>
   %3 = llvm.insertvalue %z32, %0[0 : i32] : !llvm.struct<(i32, !llvm.ptr<i32>)>
   %4 = llvm.insertvalue %2, %3[1 : i32] : !llvm.struct<(i32, !llvm.ptr<i32>)>
   llvm.return %4 : !llvm.struct<(i32, !llvm.ptr<i32>)>