From 162f7572067d7d2d70202f5ff42532adf6f75517 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Thu, 6 Apr 2023 16:10:45 +0000 Subject: [PATCH] [mlir][LLVM] Add an attribute to control use of bare-pointer calling convention. Currently the use of bare pointer calling convention is controlled globally through use of an option in the `LLVMTypeConverter`. To allow more fine-grained control use an attribute on a function to drive the calling convention to use. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D147494 --- .../mlir/Conversion/LLVMCommon/TypeConverter.h | 11 +- mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 206 +++++++++++---------- mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 3 +- mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp | 46 +++-- mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 12 +- .../Conversion/FuncToLLVM/calling-convention.mlir | 64 +++++++ .../Conversion/FuncToLLVM/func-memref-return.mlir | 6 +- 7 files changed, 218 insertions(+), 130 deletions(-) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h index b13b88d..6005751 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h @@ -53,20 +53,23 @@ public: /// one and results are packed into a wrapped LLVM IR structure type. `result` /// is populated with argument mapping. Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, + bool useBarePtrCallConv, SignatureConversion &result); /// Convert a non-empty list of types to be returned from a function into a /// supported LLVM IR type. In particular, if more than one value is /// returned, create an LLVM IR structure type with elements that correspond /// to each of the MLIR types converted with `convertType`. - Type packFunctionResults(TypeRange types); + Type packFunctionResults(TypeRange types, + bool useBarePointerCallConv = false); /// Convert a type in the context of the default or bare pointer calling /// convention. Calling convention sensitive types, such as MemRefType and /// UnrankedMemRefType, are converted following the specific rules for the /// calling convention. Calling convention independent types are converted /// following the default LLVM type conversions. - Type convertCallingConventionType(Type type); + Type convertCallingConventionType(Type type, + bool useBarePointerCallConv = false); /// Promote the bare pointers in 'values' that resulted from memrefs to /// descriptors. 'stdTypes' holds the types of 'values' before the conversion @@ -95,8 +98,8 @@ public: /// of the platform-specific C/C++ ABI lowering related to struct argument /// passing. 
SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands, - ValueRange operands, - OpBuilder &builder); + ValueRange operands, OpBuilder &builder, + bool useBarePtrCallConv = false); /// Promote the LLVM struct representation of one MemRef descriptor to stack /// and use pointer to struct to avoid the complexity of the platform-specific diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 7200b2b..86394aa 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -58,6 +58,14 @@ using namespace mlir; static constexpr StringRef varargsAttrName = "func.varargs"; static constexpr StringRef linkageAttrName = "llvm.linkage"; +static constexpr StringRef barePtrAttrName = "llvm.bareptr"; + +/// Return `true` if the `op` should use bare pointer calling convention. +static bool shouldUseBarePtrCallConv(Operation *op, + LLVMTypeConverter *typeConverter) { + return (op && op->hasAttr(barePtrAttrName)) || + typeConverter->getOptions().useBarePtrCallConv; +} /// Only retain those attributes that are not constructed by /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument @@ -267,6 +275,55 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, } } +/// Modifies the body of the function to construct the `MemRefDescriptor` from +/// the bare pointer calling convention lowering of `memref` types. +static void modifyFuncOpToUseBarePtrCallingConv( +    ConversionPatternRewriter &rewriter, Location loc, +    LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp, +    TypeRange oldArgTypes) { +  if (funcOp.getBody().empty()) +    return; + +  // Promote bare pointers from memref arguments to memref descriptors at the +  // beginning of the function so that all the memrefs in the function have a +  // uniform representation. +  Block *entryBlock = &funcOp.getBody().front(); +  auto blockArgs = entryBlock->getArguments(); +  assert(blockArgs.size() == oldArgTypes.size() && +         "The number of arguments and types doesn't match"); + +  OpBuilder::InsertionGuard guard(rewriter); +  rewriter.setInsertionPointToStart(entryBlock); +  for (auto it : llvm::zip(blockArgs, oldArgTypes)) { +    BlockArgument arg = std::get<0>(it); +    Type argTy = std::get<1>(it); + +    // Unranked memrefs are not supported in the bare pointer calling +    // convention. We should have bailed out before in the presence of +    // unranked memrefs. +    assert(!argTy.isa<UnrankedMemRefType>() && +           "Unranked memref is not supported"); +    auto memrefTy = argTy.dyn_cast<MemRefType>(); +    if (!memrefTy) +      continue; + +    // Replace barePtr with a placeholder (undef), promote barePtr to a ranked +    // or unranked memref descriptor and replace placeholder with the last +    // instruction of the memref descriptor. +    // TODO: The placeholder is needed to avoid replacing barePtr uses in the +    // MemRef descriptor instructions. We may want to have a utility in the +    // rewriter to properly handle this use case.
+    Location loc = funcOp.getLoc(); +    auto placeholder = rewriter.create<LLVM::UndefOp>( +        loc, typeConverter.convertType(memrefTy)); +    rewriter.replaceUsesOfBlockArgument(arg, placeholder); + +    Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter, +                                                   memrefTy, arg); +    rewriter.replaceOp(placeholder, {desc}); +  } +} + namespace { struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> { @@ -284,7 +341,7 @@ protected: TypeConverter::SignatureConversion result(funcOp.getNumArguments()); auto llvmType = getTypeConverter()->convertFunctionSignature( funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(), - result); + shouldUseBarePtrCallConv(funcOp, getTypeConverter()), result); if (!llvmType) return nullptr; @@ -415,89 +472,24 @@ struct FuncOpConversion : public FuncOpConversionBase { if (!newFuncOp) return failure(); - if (funcOp->getAttrOfType<UnitAttr>( - LLVM::LLVMDialect::getEmitCWrapperAttrName())) { - if (newFuncOp.isVarArg()) - return funcOp->emitError("C interface for variadic functions is not " - "supported yet."); - - if (newFuncOp.isExternal()) - wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(), - funcOp, newFuncOp); - else - wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(), - funcOp, newFuncOp); - } - - rewriter.eraseOp(funcOp); - return success(); - } -}; - -/// FuncOp legalization pattern that converts MemRef arguments to bare pointers -/// to the MemRef element type. This will impact the calling convention and ABI. -struct BarePtrFuncOpConversion : public FuncOpConversionBase { - using FuncOpConversionBase::FuncOpConversionBase; - - LogicalResult - matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - // TODO: bare ptr conversion could be handled by argument materialization - // and most of the code below would go away. But to do this, we would need a - // way to distinguish between FuncOp and other regions in the - // addArgumentMaterialization hook. + if (!shouldUseBarePtrCallConv(funcOp, this->getTypeConverter())) { + if (funcOp->getAttrOfType<UnitAttr>( + LLVM::LLVMDialect::getEmitCWrapperAttrName())) { + if (newFuncOp.isVarArg()) + return funcOp->emitError("C interface for variadic functions is not " + "supported yet."); - - // Store the type of memref-typed arguments before the conversion so that we - // can promote them to MemRef descriptor at the beginning of the function. - SmallVector<Type, 8> oldArgTypes = - llvm::to_vector<8>(funcOp.getFunctionType().getInputs()); - - auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); - if (!newFuncOp) - return failure(); - if (newFuncOp.getBody().empty()) { - rewriter.eraseOp(funcOp); - return success(); - } - - // Promote bare pointers from memref arguments to memref descriptors at the - // beginning of the function so that all the memrefs in the function have a - // uniform representation. - Block *entryBlock = &newFuncOp.getBody().front(); - auto blockArgs = entryBlock->getArguments(); - assert(blockArgs.size() == oldArgTypes.size() && - "The number of arguments and types doesn't match"); - - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(entryBlock); - for (auto it : llvm::zip(blockArgs, oldArgTypes)) { - BlockArgument arg = std::get<0>(it); - Type argTy = std::get<1>(it); - - // Unranked memrefs are not supported in the bare pointer calling - // convention. We should have bailed out before in the presence of - // unranked memrefs.
- assert(!argTy.isa<UnrankedMemRefType>() && - "Unranked memref is not supported"); - auto memrefTy = argTy.dyn_cast<MemRefType>(); - if (!memrefTy) - continue; - - // Replace barePtr with a placeholder (undef), promote barePtr to a ranked - // or unranked memref descriptor and replace placeholder with the last - // instruction of the memref descriptor. - // TODO: The placeholder is needed to avoid replacing barePtr uses in the - // MemRef descriptor instructions. We may want to have a utility in the - // rewriter to properly handle this use case. - Location loc = funcOp.getLoc(); - auto placeholder = rewriter.create<LLVM::UndefOp>( - loc, getTypeConverter()->convertType(memrefTy)); - rewriter.replaceUsesOfBlockArgument(arg, placeholder); - - Value desc = MemRefDescriptor::fromStaticShape( - rewriter, loc, *getTypeConverter(), memrefTy, arg); - rewriter.replaceOp(placeholder, {desc}); + if (newFuncOp.isExternal()) + wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(), + funcOp, newFuncOp); + else + wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(), + funcOp, newFuncOp); + } + } else { + modifyFuncOpToUseBarePtrCallingConv(rewriter, funcOp.getLoc(), + *getTypeConverter(), newFuncOp, + funcOp.getFunctionType().getInputs()); } rewriter.eraseOp(funcOp); @@ -535,23 +527,24 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { using Super = CallOpInterfaceLowering<CallOpType>; using Base = ConvertOpToLLVMPattern<CallOpType>; - LogicalResult - matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + LogicalResult matchAndRewriteImpl(CallOpType callOp, + typename CallOpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter, + bool useBarePtrCallConv = false) const { // Pack the result types into a struct. Type packedResult = nullptr; unsigned numResults = callOp.getNumResults(); auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); if (numResults != 0) { - if (!(packedResult = - this->getTypeConverter()->packFunctionResults(resultTypes))) + if (!(packedResult = this->getTypeConverter()->packFunctionResults( + resultTypes, useBarePtrCallConv))) return failure(); } auto promoted = this->getTypeConverter()->promoteOperands( callOp.getLoc(), /*opOperands=*/callOp->getOperands(), - adaptor.getOperands(), rewriter); + adaptor.getOperands(), rewriter, useBarePtrCallConv); auto newOp = rewriter.create<LLVM::CallOp>( callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), promoted, callOp->getAttrs()); @@ -570,7 +563,7 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { } } - if (this->getTypeConverter()->getOptions().useBarePtrCallConv) { + if (useBarePtrCallConv) { // For the bare-ptr calling convention, promote memref results to // descriptors.
assert(results.size() == resultTypes.size() && @@ -590,11 +583,28 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> { using Super::Super; + + LogicalResult + matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + bool useBarePtrCallConv = false; + if (Operation *callee = SymbolTable::lookupNearestSymbolFrom( + callOp, callOp.getCalleeAttr())) { + useBarePtrCallConv = shouldUseBarePtrCallConv(callee, getTypeConverter()); + } + return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv); + } }; struct CallIndirectOpLowering : public CallOpInterfaceLowering<func::CallIndirectOp> { using Super::Super; + + LogicalResult + matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter); + } }; struct UnrealizedConversionCastOpLowering @@ -640,7 +650,10 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> { unsigned numArguments = op.getNumOperands(); SmallVector<Value, 4> updatedOperands; - if (getTypeConverter()->getOptions().useBarePtrCallConv) { + auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>(); + bool useBarePtrCallConv = + shouldUseBarePtrCallConv(funcOp, this->getTypeConverter()); + if (useBarePtrCallConv) { // For the bare-ptr calling convention, extract the aligned pointer to // be returned from the memref descriptor. for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { @@ -649,7 +662,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> { if (oldTy.isa<MemRefType>() && getTypeConverter()->canConvertToBarePtr( oldTy.cast<MemRefType>())) { MemRefDescriptor memrefDesc(newOperand); - newOperand = memrefDesc.alignedPtr(rewriter, loc); + newOperand = memrefDesc.allocatedPtr(rewriter, loc); } else if (oldTy.isa<UnrankedMemRefType>()) { // Unranked memref is not supported in the bare pointer calling // convention.
- auto packedType = - getTypeConverter()->packFunctionResults(op.getOperandTypes()); + auto packedType = getTypeConverter()->packFunctionResults( + op.getOperandTypes(), useBarePtrCallConv); if (!packedType) { return rewriter.notifyMatchFailure(op, "could not convert result types"); } @@ -692,10 +705,7 @@ void mlir::populateFuncToLLVMFuncOpConversionPattern( LLVMTypeConverter &converter, RewritePatternSet &patterns) { - if (converter.getOptions().useBarePtrCallConv) - patterns.add<BarePtrFuncOpConversion>(converter); - else - patterns.add<FuncOpConversion>(converter); + patterns.add<FuncOpConversion>(converter); } void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index ec0d240..82c73b5 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -47,7 +47,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, TypeConverter::SignatureConversion signatureConversion( gpuFuncOp.front().getNumArguments()); Type funcType = getTypeConverter()->convertFunctionSignature( - gpuFuncOp.getFunctionType(), /*isVariadic=*/false, signatureConversion); + gpuFuncOp.getFunctionType(), /*isVariadic=*/false, + getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion); // Create the new function operation. Only copy those attributes that are // not specific to function modeling. diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index e24be1d..833ea36 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -209,8 +209,8 @@ Type LLVMTypeConverter::convertComplexType(ComplexType type) { // pointer-to-function types. Type LLVMTypeConverter::convertFunctionType(FunctionType type) { SignatureConversion conversion(type.getNumInputs()); - Type converted = - convertFunctionSignature(type, /*isVariadic=*/false, conversion); + Type converted = convertFunctionSignature( + type, /*isVariadic=*/false, options.useBarePtrCallConv, conversion); if (!converted) return {}; return getPointerType(converted); @@ -221,12 +221,12 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) { // Function has one VoidType result. If MLIR Function has more than one result, // they are into an LLVM StructType in their order of appearance. Type LLVMTypeConverter::convertFunctionSignature( - FunctionType funcTy, bool isVariadic, + FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, LLVMTypeConverter::SignatureConversion &result) { // Select the argument converter depending on the calling convention. - auto funcArgConverter = options.useBarePtrCallConv - ? barePtrFuncArgTypeConverter - : structFuncArgTypeConverter; + useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv; + auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter + : structFuncArgTypeConverter; // Convert argument types one by one and check for errors. for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) { SmallVector<Type, 8> converted; @@ -238,9 +238,10 @@ Type LLVMTypeConverter::convertFunctionSignature( // If function does not return anything, create the void result type, // if it returns on element, convert it, otherwise pack the result types into // a struct. - Type resultType = funcTy.getNumResults() == 0 - ?
LLVM::LLVMVoidType::get(&getContext()) - : packFunctionResults(funcTy.getResults()); + Type resultType = + funcTy.getNumResults() == 0 + ? LLVM::LLVMVoidType::get(&getContext()) + : packFunctionResults(funcTy.getResults(), useBarePtrCallConv); if (!resultType) return {}; return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(), @@ -472,8 +473,9 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) { /// UnrankedMemRefType, are converted following the specific rules for the /// calling convention. Calling convention independent types are converted /// following the default LLVM type conversions. -Type LLVMTypeConverter::convertCallingConventionType(Type type) { - if (options.useBarePtrCallConv) +Type LLVMTypeConverter::convertCallingConventionType(Type type, + bool useBarePtrCallConv) { + if (useBarePtrCallConv) if (auto memrefTy = type.dyn_cast<BaseMemRefType>()) return convertMemRefToBarePtr(memrefTy); @@ -498,16 +500,18 @@ void LLVMTypeConverter::promoteBarePtrsToDescriptors( /// supported LLVM IR type. In particular, if more than one value is returned, /// create an LLVM IR structure type with elements that correspond to each of /// the MLIR types converted with `convertType`. -Type LLVMTypeConverter::packFunctionResults(TypeRange types) { +Type LLVMTypeConverter::packFunctionResults(TypeRange types, + bool useBarePtrCallConv) { assert(!types.empty() && "expected non-empty list of type"); + useBarePtrCallConv |= options.useBarePtrCallConv; if (types.size() == 1) - return convertCallingConventionType(types.front()); + return convertCallingConventionType(types.front(), useBarePtrCallConv); SmallVector<Type> resultTypes; resultTypes.reserve(types.size()); for (auto t : types) { - auto converted = convertCallingConventionType(t); + auto converted = convertCallingConventionType(t, useBarePtrCallConv); if (!converted || !LLVM::isCompatibleType(converted)) return {}; resultTypes.push_back(converted); @@ -530,17 +534,18 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, return allocated; } -SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(Location loc, - ValueRange opOperands, - ValueRange operands, - OpBuilder &builder) { +SmallVector<Value, 4> +LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands, + ValueRange operands, OpBuilder &builder, + bool useBarePtrCallConv) { SmallVector<Value, 4> promotedOperands; promotedOperands.reserve(operands.size()); + useBarePtrCallConv |= options.useBarePtrCallConv; for (auto it : llvm::zip(opOperands, operands)) { auto operand = std::get<0>(it); auto llvmOperand = std::get<1>(it); - if (options.useBarePtrCallConv) { + if (useBarePtrCallConv) { // For the bare-ptr calling convention, we only have to extract the // aligned pointer of a memref.
if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) { @@ -603,7 +608,8 @@ LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter, LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, SmallVectorImpl<Type> &result) { - auto llvmTy = converter.convertCallingConventionType(type); + auto llvmTy = + converter.convertCallingConventionType(type, /*useBarePtrCallConv=*/true); if (!llvmTy) return failure(); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 2cdce91..b938947 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -338,7 +338,8 @@ public: auto dstType = typeConverter.convertType(op.getPointer().getType()); if (!dstType) return failure(); - rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.getVariable()); + rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, + op.getVariable()); return success(); } }; @@ -582,7 +583,8 @@ public: } rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>( - op, adaptor.getComposite(), LLVM::convertArrayToIndices(op.getIndices())); + op, adaptor.getComposite(), + LLVM::convertArrayToIndices(op.getIndices())); return success(); } }; @@ -1146,7 +1148,8 @@ public: Block *falseBlock = condBrOp.getFalseBlock(); rewriter.setInsertionPointToEnd(currentBlock); rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock, - condBrOp.getTrueTargetOperands(), falseBlock, + condBrOp.getTrueTargetOperands(), + falseBlock, condBrOp.getFalseTargetOperands()); rewriter.inlineRegionBefore(op.getBody(), continueBlock); @@ -1329,7 +1332,8 @@ public: TypeConverter::SignatureConversion signatureConverter( funcType.getNumInputs()); auto llvmType = typeConverter.convertFunctionSignature( - funcType, /*isVariadic=*/false, signatureConverter); + funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false, + signatureConverter); if (!llvmType) return failure(); diff --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir index daa824d..b1c065e 100644 --- a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir +++ b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir @@ -242,3 +242,67 @@ func.func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memr // CHECK-LABEL: @_mlir_ciface_return_two_var_memref // CHECK-SAME: (%{{.*}}: !llvm.ptr, // CHECK-SAME: %{{.*}}: !llvm.ptr) + +// CHECK-LABEL: llvm.func @bare_ptr_calling_conv( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr +// CHECK-SAME: -> !llvm.ptr +func.func @bare_ptr_calling_conv(%arg0: memref<4x3xf32>, %arg1 : index, %arg2 : index, %arg3 : f32) + -> (memref<4x3xf32>) attributes { llvm.bareptr } { + // CHECK: %[[UNDEF_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: %[[INSERT_ALLOCPTR:.*]] = llvm.insertvalue %[[ARG0]], %[[UNDEF_DESC]][0] + // CHECK: %[[INSERT_ALIGNEDPTR:.*]] = llvm.insertvalue %[[ARG0]], %[[INSERT_ALLOCPTR]][1] + // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 + // CHECK: %[[INSERT_OFFSET:.*]] = llvm.insertvalue %[[C0]], %[[INSERT_ALIGNEDPTR]][2] + // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64 + // CHECK: %[[INSERT_DIM0:.*]] = llvm.insertvalue %[[C4]], %[[INSERT_OFFSET]][3, 0] + // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64 + // CHECK: %[[INSERT_STRIDE0:.*]] = llvm.insertvalue %[[C3]], %[[INSERT_DIM0]][4, 0] + // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64 + // CHECK: %[[INSERT_DIM1:.*]] =
llvm.insertvalue %[[C3]], %[[INSERT_STRIDE0]][3, 1] + // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK: %[[INSERT_STRIDE1:.*]] = llvm.insertvalue %[[C1]], %[[INSERT_DIM1]][4, 1] + + // CHECK: %[[ALIGNEDPTR:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1] + // CHECK: %[[STOREPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR]] + // CHECK: llvm.store %{{.*}}, %[[STOREPTR]] + memref.store %arg3, %arg0[%arg1, %arg2] : memref<4x3xf32> + + // CHECK: llvm.return %[[ARG0]] + return %arg0 : memref<4x3xf32> +} + +// CHECK-LABEL: llvm.func @bare_ptr_calling_conv_multiresult( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr +// CHECK-SAME: -> !llvm.struct<(f32, ptr)> +func.func @bare_ptr_calling_conv_multiresult(%arg0: memref<4x3xf32>, %arg1 : index, %arg2 : index, %arg3 : f32) + -> (f32, memref<4x3xf32>) attributes { llvm.bareptr } { + // CHECK: %[[UNDEF_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: %[[INSERT_ALLOCPTR:.*]] = llvm.insertvalue %[[ARG0]], %[[UNDEF_DESC]][0] + // CHECK: %[[INSERT_ALIGNEDPTR:.*]] = llvm.insertvalue %[[ARG0]], %[[INSERT_ALLOCPTR]][1] + // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 + // CHECK: %[[INSERT_OFFSET:.*]] = llvm.insertvalue %[[C0]], %[[INSERT_ALIGNEDPTR]][2] + // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64 + // CHECK: %[[INSERT_DIM0:.*]] = llvm.insertvalue %[[C4]], %[[INSERT_OFFSET]][3, 0] + // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64 + // CHECK: %[[INSERT_STRIDE0:.*]] = llvm.insertvalue %[[C3]], %[[INSERT_DIM0]][4, 0] + // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64 + // CHECK: %[[INSERT_DIM1:.*]] = llvm.insertvalue %[[C3]], %[[INSERT_STRIDE0]][3, 1] + // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK: %[[INSERT_STRIDE1:.*]] = llvm.insertvalue %[[C1]], %[[INSERT_DIM1]][4, 1] + + // CHECK: %[[ALIGNEDPTR:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1] + // CHECK: %[[STOREPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR]] + // CHECK: llvm.store %{{.*}}, %[[STOREPTR]] + memref.store %arg3, %arg0[%arg1, %arg2] : memref<4x3xf32> + + // CHECK: %[[ALIGNEDPTR0:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1] + // CHECK: %[[LOADPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR0]] + // CHECK: %[[RETURN0:.*]] = llvm.load %[[LOADPTR]] + %0 = memref.load %arg0[%arg1, %arg2] : memref<4x3xf32> + + // CHECK: %[[RETURN_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(f32, ptr)> + // CHECK: %[[INSERT_RETURN0:.*]] = llvm.insertvalue %[[RETURN0]], %[[RETURN_DESC]][0] + // CHECK: %[[INSERT_RETURN1:.*]] = llvm.insertvalue %[[ARG0]], %[[INSERT_RETURN0]][1] + // CHECK: llvm.return %[[INSERT_RETURN1]] + return %0, %arg0 : f32, memref<4x3xf32> +} diff --git a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir index 8663ce8..956c2981 100644 --- a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir +++ b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir @@ -27,7 +27,7 @@ func.func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> // BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : i64 // BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, 
array<2 x i64>)> +// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: llvm.return %[[base1]] : !llvm.ptr return %static : memref<32x18xf32> } @@ -56,7 +56,7 @@ func.func @check_static_return_with_offset(%static : memref<32x18xf32, strided<[ // BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : i64 // BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: llvm.return %[[base1]] : !llvm.ptr return %static : memref<32x18xf32, strided<[22,1], offset: 7>> } @@ -82,7 +82,7 @@ func.func @check_memref_func_call(%in : memref<10xi8>) -> memref<20xi8> { // BAREPTR-NEXT: %[[c1:.*]] = llvm.mlir.constant(1 : index) : i64 // BAREPTR-NEXT: %[[outDesc:.*]] = llvm.insertvalue %[[c1]], %[[desc6]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %res = call @foo(%in) : (memref<10xi8>) -> (memref<20xi8>) - // BAREPTR-NEXT: %[[res:.*]] = llvm.extractvalue %[[outDesc]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // BAREPTR-NEXT: %[[res:.*]] = llvm.extractvalue %[[outDesc]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // BAREPTR-NEXT: llvm.return %[[res]] : !llvm.ptr return %res : memref<20xi8> } -- 2.7.4
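
For reference, a minimal usage sketch of the new attribute (function names and shapes are illustrative and not part of the patch; the attribute name and the per-function behaviour mirror the `bare_ptr_calling_conv` tests added above). When run through `-convert-func-to-llvm`, the function carrying `llvm.bareptr` lowers its memref argument and result to a raw `!llvm.ptr`, while the other function in the module keeps the default descriptor-based convention, and the direct call is promoted according to the callee's attribute:

// Illustrative input only: opt a single function into the bare-pointer
// convention with the new unit attribute.
func.func @bare(%arg0: memref<8xf32>) -> memref<8xf32> attributes { llvm.bareptr } {
  return %arg0 : memref<8xf32>
}

// A caller without the attribute keeps the default calling convention; its
// call to @bare is lowered using the callee's `llvm.bareptr` attribute.
func.func @caller(%arg0: memref<8xf32>) -> memref<8xf32> {
  %0 = call @bare(%arg0) : (memref<8xf32>) -> memref<8xf32>
  return %0 : memref<8xf32>
}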