[mlir] Make it possible to directly supply constant values to LLVM GEPOp
authorAlex Zinenko <zinenko@google.com>
Thu, 6 Jan 2022 22:29:15 +0000 (23:29 +0100)
committerAlex Zinenko <zinenko@google.com>
Fri, 7 Jan 2022 08:56:01 +0000 (09:56 +0100)
In LLVM IR, the GEP indices that correspond to structures are required to be
i32 constants. MLIR models constants as just values defined by special
operations, and there is no verification that it is the case for structure
indices in GEP. Furthermore, some common transformations such as control flow
simplification may lead to the operands becoming non-constant. Make it possible
to directly supply constant values to LLVM GEPOp to guarantee they remain
constant until the translation to LLVM IR. This is not yet a requirement and
the verifier is not modified, this will be introduced separately.

Reviewed By: wsmoses

Differential Revision: https://reviews.llvm.org/D116757

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/test/Dialect/LLVMIR/roundtrip.mlir
mlir/test/Target/LLVMIR/llvmir.mlir

index 9555173..dd4def5 100644 (file)
@@ -315,17 +315,39 @@ def LLVM_AllocaOp : LLVM_Op<"alloca">, MemoryOpWithAlignmentBase {
   let printer = [{ printAllocaOp(p, *this); }];
 }
 
-def LLVM_GEPOp
-    : LLVM_Op<"getelementptr", [NoSideEffect]>,
-      LLVM_Builder<
-          "$res = builder.CreateGEP("
-          " $base->getType()->getPointerElementType(), $base, $indices);"> {
+def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$base,
-                   Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$indices);
+                   Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$indices,
+                   I32ElementsAttr:$structIndices);
   let results = (outs LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$res);
-  let builders = [LLVM_OneResultOpBuilder];
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices,
+               CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+    OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices,
+               "ArrayRef<int32_t>":$structIndices,
+               CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+  ];
+  let llvmBuilder = [{
+    SmallVector<llvm::Value *> indices;
+    indices.reserve($structIndices.size());
+    unsigned operandIdx = 0;
+    for (int32_t structIndex : $structIndices.getValues<int32_t>()) {
+      if (structIndex == GEPOp::kDynamicIndex)
+        indices.push_back($indices[operandIdx++]);
+      else
+        indices.push_back(builder.getInt32(structIndex));
+    }
+    $res = builder.CreateGEP(
+      $base->getType()->getPointerElementType(), $base, indices);
+  }];
   let assemblyFormat = [{
-    $base `[` $indices `]` attr-dict `:` functional-type(operands, results)
+    $base `[` custom<GEPIndices>($indices, $structIndices) `]` attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    constexpr static int kDynamicIndex = std::numeric_limits<int32_t>::min();
   }];
   let hasFolder = 1;
 }
index f7f8b6b..ab498da 100644 (file)
@@ -790,8 +790,8 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
 
   Type elementPtrType = getElementPtrType(memRefType);
   Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
-  Value gepPtr = rewriter.create<LLVM::GEPOp>(
-      loc, elementPtrType, ArrayRef<Value>{nullPtr, numElements});
+  Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
+                                              ArrayRef<Value>{numElements});
   auto sizeBytes =
       rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
 
index 0003bd8..41e8cef 100644 (file)
@@ -162,8 +162,8 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
   // Buffer size in bytes.
   Type elementPtrType = getElementPtrType(memRefType);
   Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
-  Value gepPtr = rewriter.create<LLVM::GEPOp>(
-      loc, elementPtrType, ArrayRef<Value>{nullPtr, runningStride});
+  Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
+                                              ArrayRef<Value>{runningStride});
   sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
 }
 
@@ -178,8 +178,8 @@ Value ConvertToLLVMPattern::getSizeInBytes(
       LLVM::LLVMPointerType::get(typeConverter->convertType(type));
   auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
   auto gep = rewriter.create<LLVM::GEPOp>(
-      loc, convertedPtrType,
-      ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
+      loc, convertedPtrType, nullPtr,
+      ArrayRef<Value>{createIndexConstant(rewriter, loc, 1)});
   return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
 }
 
index 9142be1..d8fa965 100644 (file)
@@ -497,10 +497,11 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
     Type elementType = typeConverter->convertType(type.getElementType());
     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
 
-    SmallVector<Value, 4> operands = {addressOf};
+    SmallVector<Value> operands;
     operands.insert(operands.end(), type.getRank() + 1,
                     createIndexConstant(rewriter, loc, 0));
-    auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);
+    auto gep =
+        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
 
     // We do not expect the memref obtained using `memref.get_global` to be
     // ever deallocated. Set the allocated pointer to be known bad value to
index e65c14e..995d2ea 100644 (file)
@@ -357,6 +357,67 @@ SwitchOp::getMutableSuccessorOperands(unsigned index) {
 }
 
 //===----------------------------------------------------------------------===//
+// Code for LLVM::GEPOp.
+//===----------------------------------------------------------------------===//
+
+void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
+                  Value basePtr, ValueRange operands,
+                  ArrayRef<NamedAttribute> attributes) {
+  build(builder, result, resultType, basePtr, operands,
+        SmallVector<int32_t>(operands.size(), LLVM::GEPOp::kDynamicIndex),
+        attributes);
+}
+
+void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
+                  Value basePtr, ValueRange indices,
+                  ArrayRef<int32_t> structIndices,
+                  ArrayRef<NamedAttribute> attributes) {
+  result.addTypes(resultType);
+  result.addAttributes(attributes);
+  result.addAttribute("structIndices", builder.getI32TensorAttr(structIndices));
+  result.addOperands(basePtr);
+  result.addOperands(indices);
+}
+
+static ParseResult
+parseGEPIndices(OpAsmParser &parser,
+                SmallVectorImpl<OpAsmParser::OperandType> &indices,
+                DenseIntElementsAttr &structIndices) {
+  SmallVector<int32_t> constantIndices;
+  do {
+    int32_t constantIndex;
+    OptionalParseResult parsedInteger =
+        parser.parseOptionalInteger(constantIndex);
+    if (parsedInteger.hasValue()) {
+      if (failed(parsedInteger.getValue()))
+        return failure();
+      constantIndices.push_back(constantIndex);
+      continue;
+    }
+
+    constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
+    if (failed(parser.parseOperand(indices.emplace_back())))
+      return failure();
+  } while (succeeded(parser.parseOptionalComma()));
+
+  structIndices = parser.getBuilder().getI32TensorAttr(constantIndices);
+  return success();
+}
+
+static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
+                            OperandRange indices,
+                            DenseIntElementsAttr structIndices) {
+  unsigned operandIdx = 0;
+  llvm::interleaveComma(structIndices.getValues<int32_t>(), printer,
+                        [&](int32_t cst) {
+                          if (cst == LLVM::GEPOp::kDynamicIndex)
+                            printer.printOperand(indices[operandIdx++]);
+                          else
+                            printer << cst;
+                        });
+}
+
+//===----------------------------------------------------------------------===//
 // Builder, printer and parser for for LLVM::LoadOp.
 //===----------------------------------------------------------------------===//
 
index 3e06f9c..f7ad338 100644 (file)
@@ -760,7 +760,8 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
     Type type = processType(inst->getType());
     if (!type)
       return failure();
-    v = b.create<GEPOp>(loc, type, ops);
+    v = b.create<GEPOp>(loc, type, ops[0],
+                        llvm::makeArrayRef(ops).drop_front());
     return success();
   }
   }
index b6a09d6..d4172e1 100644 (file)
@@ -170,6 +170,16 @@ func @ops(%arg0: i32, %arg1: f32,
   llvm.return
 }
 
+// CHECK-LABEL: @gep
+llvm.func @gep(%ptr: !llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, %idx: i64,
+               %ptr2: !llvm.ptr<struct<(array<10xf32>)>>) {
+  // CHECK: llvm.getelementptr %{{.*}}[%{{.*}}, 1, 0] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
+  llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
+  // CHECK: llvm.getelementptr %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr<struct<(array<10 x f32>)>>, i64, i64) -> !llvm.ptr<f32>
+  llvm.getelementptr %ptr2[%idx, 0, %idx] : (!llvm.ptr<struct<(array<10 x f32>)>>, i64, i64) -> !llvm.ptr<f32>
+  llvm.return
+}
+
 // An larger self-contained function.
 // CHECK-LABEL: llvm.func @foo(%{{.*}}: i32) -> !llvm.struct<(i32, f64, i32)> {
 llvm.func @foo(%arg0: i32) -> !llvm.struct<(i32, f64, i32)> {
index 54dfd51..04a65f8 100644 (file)
@@ -975,6 +975,16 @@ llvm.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32) -> !llvm.struct<(
   llvm.return %10 : !llvm.struct<(f32, i32)>
 }
 
+// CHECK-LABEL: @gep
+llvm.func @gep(%ptr: !llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, %idx: i64,
+               %ptr2: !llvm.ptr<struct<(array<10xf32>)>>) {
+  // CHECK: = getelementptr { i32, { i32, float } }, { i32, { i32, float } }* %{{.*}}, i64 %{{.*}}, i32 1, i32 0
+  llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
+  // CHECK: = getelementptr { [10 x float] }, { [10 x float] }* %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
+  llvm.getelementptr %ptr2[%idx, 0, %idx] : (!llvm.ptr<struct<(array<10xf32>)>>, i64, i64) -> !llvm.ptr<f32>
+  llvm.return
+}
+
 //
 // Indirect function calls
 //