From cafaa3503643c047b670dcc387e22096930d2d6c Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 6 Jan 2022 23:29:15 +0100 Subject: [PATCH] [mlir] Make it possible to directly supply constant values to LLVM GEPOp In LLVM IR, the GEP indices that correspond to structures are required to be i32 constants. MLIR models constants as just values defined by special operations, and there is no verification that it is the case for structure indices in GEP. Furthermore, some common transformations such as control flow simplification may lead to the operands becoming non-constant. Make it possible to directly supply constant values to LLVM GEPOp to guarantee they remain constant until the translation to LLVM IR. This is not yet a requirement and the verifier is not modified, this will be introduced separately. Reviewed By: wsmoses Differential Revision: https://reviews.llvm.org/D116757 --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 38 +++++++++++--- .../Conversion/GPUCommon/GPUToLLVMConversion.cpp | 4 +- mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 8 +-- mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 5 +- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 61 ++++++++++++++++++++++ mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp | 3 +- mlir/test/Dialect/LLVMIR/roundtrip.mlir | 10 ++++ mlir/test/Target/LLVMIR/llvmir.mlir | 10 ++++ 8 files changed, 122 insertions(+), 17 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 9555173..dd4def5 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -315,17 +315,39 @@ def LLVM_AllocaOp : LLVM_Op<"alloca">, MemoryOpWithAlignmentBase { let printer = [{ printAllocaOp(p, *this); }]; } -def LLVM_GEPOp - : LLVM_Op<"getelementptr", [NoSideEffect]>, - LLVM_Builder< - "$res = builder.CreateGEP(" - " $base->getType()->getPointerElementType(), $base, $indices);"> { +def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> { let arguments = (ins LLVM_ScalarOrVectorOf:$base, - Variadic>:$indices); + Variadic>:$indices, + I32ElementsAttr:$structIndices); let results = (outs LLVM_ScalarOrVectorOf:$res); - let builders = [LLVM_OneResultOpBuilder]; + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices, + CArg<"ArrayRef", "{}">:$attributes)>, + OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices, + "ArrayRef":$structIndices, + CArg<"ArrayRef", "{}">:$attributes)>, + ]; + let llvmBuilder = [{ + SmallVector indices; + indices.reserve($structIndices.size()); + unsigned operandIdx = 0; + for (int32_t structIndex : $structIndices.getValues()) { + if (structIndex == GEPOp::kDynamicIndex) + indices.push_back($indices[operandIdx++]); + else + indices.push_back(builder.getInt32(structIndex)); + } + $res = builder.CreateGEP( + $base->getType()->getPointerElementType(), $base, indices); + }]; let assemblyFormat = [{ - $base `[` $indices `]` attr-dict `:` functional-type(operands, results) + $base `[` custom($indices, $structIndices) `]` attr-dict + `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + constexpr static int kDynamicIndex = std::numeric_limits::min(); }]; let hasFolder = 1; } diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index f7f8b6b..ab498da 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -790,8 +790,8 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( Type elementPtrType = getElementPtrType(memRefType); Value nullPtr = rewriter.create(loc, elementPtrType); - Value gepPtr = rewriter.create( - loc, elementPtrType, ArrayRef{nullPtr, numElements}); + Value gepPtr = rewriter.create(loc, elementPtrType, nullPtr, + ArrayRef{numElements}); auto sizeBytes = rewriter.create(loc, getIndexType(), gepPtr); diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 0003bd8..41e8cef 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -162,8 +162,8 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes( // Buffer size in bytes. Type elementPtrType = getElementPtrType(memRefType); Value nullPtr = rewriter.create(loc, elementPtrType); - Value gepPtr = rewriter.create( - loc, elementPtrType, ArrayRef{nullPtr, runningStride}); + Value gepPtr = rewriter.create(loc, elementPtrType, nullPtr, + ArrayRef{runningStride}); sizeBytes = rewriter.create(loc, getIndexType(), gepPtr); } @@ -178,8 +178,8 @@ Value ConvertToLLVMPattern::getSizeInBytes( LLVM::LLVMPointerType::get(typeConverter->convertType(type)); auto nullPtr = rewriter.create(loc, convertedPtrType); auto gep = rewriter.create( - loc, convertedPtrType, - ArrayRef{nullPtr, createIndexConstant(rewriter, loc, 1)}); + loc, convertedPtrType, nullPtr, + ArrayRef{createIndexConstant(rewriter, loc, 1)}); return rewriter.create(loc, getIndexType(), gep); } diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 9142be1..d8fa965 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -497,10 +497,11 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering { Type elementType = typeConverter->convertType(type.getElementType()); Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace); - SmallVector operands = {addressOf}; + SmallVector operands; operands.insert(operands.end(), type.getRank() + 1, createIndexConstant(rewriter, loc, 0)); - auto gep = rewriter.create(loc, elementPtrType, operands); + auto gep = + rewriter.create(loc, elementPtrType, addressOf, operands); // We do not expect the memref obtained using `memref.get_global` to be // ever deallocated. Set the allocated pointer to be known bad value to diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index e65c14e..995d2ea 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -357,6 +357,67 @@ SwitchOp::getMutableSuccessorOperands(unsigned index) { } //===----------------------------------------------------------------------===// +// Code for LLVM::GEPOp. +//===----------------------------------------------------------------------===// + +void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, + Value basePtr, ValueRange operands, + ArrayRef attributes) { + build(builder, result, resultType, basePtr, operands, + SmallVector(operands.size(), LLVM::GEPOp::kDynamicIndex), + attributes); +} + +void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, + Value basePtr, ValueRange indices, + ArrayRef structIndices, + ArrayRef attributes) { + result.addTypes(resultType); + result.addAttributes(attributes); + result.addAttribute("structIndices", builder.getI32TensorAttr(structIndices)); + result.addOperands(basePtr); + result.addOperands(indices); +} + +static ParseResult +parseGEPIndices(OpAsmParser &parser, + SmallVectorImpl &indices, + DenseIntElementsAttr &structIndices) { + SmallVector constantIndices; + do { + int32_t constantIndex; + OptionalParseResult parsedInteger = + parser.parseOptionalInteger(constantIndex); + if (parsedInteger.hasValue()) { + if (failed(parsedInteger.getValue())) + return failure(); + constantIndices.push_back(constantIndex); + continue; + } + + constantIndices.push_back(LLVM::GEPOp::kDynamicIndex); + if (failed(parser.parseOperand(indices.emplace_back()))) + return failure(); + } while (succeeded(parser.parseOptionalComma())); + + structIndices = parser.getBuilder().getI32TensorAttr(constantIndices); + return success(); +} + +static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, + OperandRange indices, + DenseIntElementsAttr structIndices) { + unsigned operandIdx = 0; + llvm::interleaveComma(structIndices.getValues(), printer, + [&](int32_t cst) { + if (cst == LLVM::GEPOp::kDynamicIndex) + printer.printOperand(indices[operandIdx++]); + else + printer << cst; + }); +} + +//===----------------------------------------------------------------------===// // Builder, printer and parser for for LLVM::LoadOp. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index 3e06f9c..f7ad338 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -760,7 +760,8 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { Type type = processType(inst->getType()); if (!type) return failure(); - v = b.create(loc, type, ops); + v = b.create(loc, type, ops[0], + llvm::makeArrayRef(ops).drop_front()); return success(); } } diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index b6a09d6..d4172e1 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -170,6 +170,16 @@ func @ops(%arg0: i32, %arg1: f32, llvm.return } +// CHECK-LABEL: @gep +llvm.func @gep(%ptr: !llvm.ptr)>>, %idx: i64, + %ptr2: !llvm.ptr)>>) { + // CHECK: llvm.getelementptr %{{.*}}[%{{.*}}, 1, 0] : (!llvm.ptr)>>, i64) -> !llvm.ptr + llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr)>>, i64) -> !llvm.ptr + // CHECK: llvm.getelementptr %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr)>>, i64, i64) -> !llvm.ptr + llvm.getelementptr %ptr2[%idx, 0, %idx] : (!llvm.ptr)>>, i64, i64) -> !llvm.ptr + llvm.return +} + // An larger self-contained function. // CHECK-LABEL: llvm.func @foo(%{{.*}}: i32) -> !llvm.struct<(i32, f64, i32)> { llvm.func @foo(%arg0: i32) -> !llvm.struct<(i32, f64, i32)> { diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 54dfd51..04a65f8 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -975,6 +975,16 @@ llvm.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32) -> !llvm.struct<( llvm.return %10 : !llvm.struct<(f32, i32)> } +// CHECK-LABEL: @gep +llvm.func @gep(%ptr: !llvm.ptr)>>, %idx: i64, + %ptr2: !llvm.ptr)>>) { + // CHECK: = getelementptr { i32, { i32, float } }, { i32, { i32, float } }* %{{.*}}, i64 %{{.*}}, i32 1, i32 0 + llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr)>>, i64) -> !llvm.ptr + // CHECK: = getelementptr { [10 x float] }, { [10 x float] }* %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}} + llvm.getelementptr %ptr2[%idx, 0, %idx] : (!llvm.ptr)>>, i64, i64) -> !llvm.ptr + llvm.return +} + // // Indirect function calls // -- 2.7.4