From 5bf86d9e88fa841f5f50f4b8e3b337191691a45d Mon Sep 17 00:00:00 2001 From: Daniil Kovalev Date: Fri, 25 Mar 2022 12:36:20 +0300 Subject: [PATCH] [NVPTX] Remove code duplication in LowerCall In D120129 we enhanced vectorization options of byval parameters. This patch removes code duplication when handling byval and non-byval cases. Differential Revision: https://reviews.llvm.org/D122381 --- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 258 +++++++++++----------------- 1 file changed, 98 insertions(+), 160 deletions(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 382e83d..11fc257 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -1441,11 +1441,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, return Chain; unsigned UniqueCallSite = GlobalUniqueCallSite.fetch_add(1); - SDValue tempChain = Chain; + SDValue TempChain = Chain; Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl); SDValue InFlag = Chain.getValue(1); - unsigned paramCount = 0; + unsigned ParamCount = 0; // Args.size() and Outs.size() need not match. // Outs.size() will be larger // * if there is an aggregate argument with multiple fields (each field @@ -1461,185 +1461,115 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) { EVT VT = Outs[OIdx].VT; Type *Ty = Args[i].Ty; + bool IsByVal = Outs[OIdx].Flags.isByVal(); - if (!Outs[OIdx].Flags.isByVal()) { - SmallVector VTs; - SmallVector Offsets; - ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets); - Align ArgAlign = getArgumentAlignment(Callee, CB, Ty, paramCount + 1, DL); - unsigned AllocSize = DL.getTypeAllocSize(Ty); - SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); - bool NeedAlign; // Does argument declaration specify alignment? - if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) { - // declare .param .align .b8 .param[]; - SDValue DeclareParamOps[] = { - Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32), - DAG.getConstant(paramCount, dl, MVT::i32), - DAG.getConstant(AllocSize, dl, MVT::i32), InFlag}; - Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, - DeclareParamOps); - NeedAlign = true; - } else { - // declare .param .b .param; - if ((VT.isInteger() || VT.isFloatingPoint()) && AllocSize < 4) { - // PTX ABI requires integral types to be at least 32 bits in - // size. FP16 is loaded/stored using i16, so it's handled - // here as well. - AllocSize = 4; - } - SDValue DeclareScalarParamOps[] = { - Chain, DAG.getConstant(paramCount, dl, MVT::i32), - DAG.getConstant(AllocSize * 8, dl, MVT::i32), - DAG.getConstant(0, dl, MVT::i32), InFlag}; - Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs, - DeclareScalarParamOps); - NeedAlign = false; - } - InFlag = Chain.getValue(1); - - // PTX Interoperability Guide 3.3(A): [Integer] Values shorter - // than 32-bits are sign extended or zero extended, depending on - // whether they are signed or unsigned types. This case applies - // only to scalar parameters and not to aggregate values. - bool ExtendIntegerParam = - Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32; - - auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign); - SmallVector StoreOperands; - for (unsigned j = 0, je = VTs.size(); j != je; ++j) { - // New store. - if (VectorInfo[j] & PVF_FIRST) { - assert(StoreOperands.empty() && "Unfinished preceding store."); - StoreOperands.push_back(Chain); - StoreOperands.push_back(DAG.getConstant(paramCount, dl, MVT::i32)); - StoreOperands.push_back(DAG.getConstant(Offsets[j], dl, MVT::i32)); - } - - EVT EltVT = VTs[j]; - SDValue StVal = OutVals[OIdx]; - if (ExtendIntegerParam) { - assert(VTs.size() == 1 && "Scalar can't have multiple parts."); - // zext/sext to i32 - StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND - : ISD::ZERO_EXTEND, - dl, MVT::i32, StVal); - } else if (EltVT.getSizeInBits() < 16) { - // Use 16-bit registers for small stores as it's the - // smallest general purpose register size supported by NVPTX. - StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal); - } - - // Record the value to store. - StoreOperands.push_back(StVal); - - if (VectorInfo[j] & PVF_LAST) { - unsigned NumElts = StoreOperands.size() - 3; - NVPTXISD::NodeType Op; - switch (NumElts) { - case 1: - Op = NVPTXISD::StoreParam; - break; - case 2: - Op = NVPTXISD::StoreParamV2; - break; - case 4: - Op = NVPTXISD::StoreParamV4; - break; - default: - llvm_unreachable("Invalid vector info."); - } - - StoreOperands.push_back(InFlag); - - // Adjust type of the store op if we've extended the scalar - // return value. - EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : VTs[j]; - MaybeAlign EltAlign; - if (NeedAlign) - EltAlign = commonAlignment(ArgAlign, Offsets[j]); - - Chain = DAG.getMemIntrinsicNode( - Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands, - TheStoreType, MachinePointerInfo(), EltAlign, - MachineMemOperand::MOStore); - InFlag = Chain.getValue(1); - - // Cleanup. - StoreOperands.clear(); - } - ++OIdx; - } - assert(StoreOperands.empty() && "Unfinished parameter store."); - if (VTs.size() > 0) - --OIdx; - ++paramCount; - continue; - } - - // ByVal arguments - // TODO: remove code duplication when handling byval and non-byval cases. SmallVector VTs; SmallVector Offsets; - Type *ETy = Args[i].IndirectType; - assert(ETy && "byval arg must have indirect type"); - ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, 0); - // declare .param .align .b8 .param[]; - unsigned sz = Outs[OIdx].Flags.getByValSize(); - SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); + assert((!IsByVal || Args[i].IndirectType) && + "byval arg must have indirect type"); + Type *ETy = (IsByVal ? Args[i].IndirectType : Ty); + ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets); + + Align ArgAlign; + if (IsByVal) { + // The ByValAlign in the Outs[OIdx].Flags is always set at this point, + // so we don't need to worry whether it's naturally aligned or not. + // See TargetLowering::LowerCallTo(). + ArgAlign = Outs[OIdx].Flags.getNonZeroByValAlign(); + + // Try to increase alignment to enhance vectorization options. + ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign( + CB->getCalledFunction(), ETy, DL)); + + // Enforce minumum alignment of 4 to work around ptxas miscompile + // for sm_50+. See corresponding alignment adjustment in + // emitFunctionParamList() for details. + ArgAlign = std::max(ArgAlign, Align(4)); + } else { + ArgAlign = getArgumentAlignment(Callee, CB, Ty, ParamCount + 1, DL); + } - // The ByValAlign in the Outs[OIdx].Flags is alway set at this point, - // so we don't need to worry about natural alignment or not. - // See TargetLowering::LowerCallTo(). - Align ArgAlign = Outs[OIdx].Flags.getNonZeroByValAlign(); + unsigned TypeSize = + (IsByVal ? Outs[OIdx].Flags.getByValSize() : DL.getTypeAllocSize(Ty)); + SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); - // Try to increase alignment to enhance vectorization options. - const Function *F = CB->getCalledFunction(); - Align AlignCandidate = getFunctionParamOptimizedAlign(F, ETy, DL); - ArgAlign = std::max(ArgAlign, AlignCandidate); - - // Enforce minumum alignment of 4 to work around ptxas miscompile - // for sm_50+. See corresponding alignment adjustment in - // emitFunctionParamList() for details. - if (ArgAlign < Align(4)) - ArgAlign = Align(4); - SDValue DeclareParamOps[] = { - Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32), - DAG.getConstant(paramCount, dl, MVT::i32), - DAG.getConstant(sz, dl, MVT::i32), InFlag}; - Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, - DeclareParamOps); + bool NeedAlign; // Does argument declaration specify alignment? + if (IsByVal || + (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128))) { + // declare .param .align .b8 .param[]; + SDValue DeclareParamOps[] = { + Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32), + DAG.getConstant(ParamCount, dl, MVT::i32), + DAG.getConstant(TypeSize, dl, MVT::i32), InFlag}; + Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, + DeclareParamOps); + NeedAlign = true; + } else { + // declare .param .b .param; + if ((VT.isInteger() || VT.isFloatingPoint()) && TypeSize < 4) { + // PTX ABI requires integral types to be at least 32 bits in + // size. FP16 is loaded/stored using i16, so it's handled + // here as well. + TypeSize = 4; + } + SDValue DeclareScalarParamOps[] = { + Chain, DAG.getConstant(ParamCount, dl, MVT::i32), + DAG.getConstant(TypeSize * 8, dl, MVT::i32), + DAG.getConstant(0, dl, MVT::i32), InFlag}; + Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs, + DeclareScalarParamOps); + NeedAlign = false; + } InFlag = Chain.getValue(1); + // PTX Interoperability Guide 3.3(A): [Integer] Values shorter + // than 32-bits are sign extended or zero extended, depending on + // whether they are signed or unsigned types. This case applies + // only to scalar parameters and not to aggregate values. + bool ExtendIntegerParam = + Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32; + auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign); SmallVector StoreOperands; for (unsigned j = 0, je = VTs.size(); j != je; ++j) { - EVT elemtype = VTs[j]; - int curOffset = Offsets[j]; - Align PartAlign = commonAlignment(ArgAlign, curOffset); + EVT EltVT = VTs[j]; + int CurOffset = Offsets[j]; + MaybeAlign PartAlign; + if (NeedAlign) + PartAlign = commonAlignment(ArgAlign, CurOffset); // New store. if (VectorInfo[j] & PVF_FIRST) { assert(StoreOperands.empty() && "Unfinished preceding store."); StoreOperands.push_back(Chain); - StoreOperands.push_back(DAG.getConstant(paramCount, dl, MVT::i32)); - StoreOperands.push_back(DAG.getConstant(curOffset, dl, MVT::i32)); + StoreOperands.push_back(DAG.getConstant(ParamCount, dl, MVT::i32)); + StoreOperands.push_back(DAG.getConstant(CurOffset, dl, MVT::i32)); } - auto PtrVT = getPointerTy(DL); - SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, OutVals[OIdx], - DAG.getConstant(curOffset, dl, PtrVT)); - SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr, - MachinePointerInfo(), PartAlign); + SDValue StVal = OutVals[OIdx]; + if (IsByVal) { + auto PtrVT = getPointerTy(DL); + SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal, + DAG.getConstant(CurOffset, dl, PtrVT)); + StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(), + PartAlign); + } else if (ExtendIntegerParam) { + assert(VTs.size() == 1 && "Scalar can't have multiple parts."); + // zext/sext to i32 + StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND + : ISD::ZERO_EXTEND, + dl, MVT::i32, StVal); + } - if (elemtype.getSizeInBits() < 16) { + if (!ExtendIntegerParam && EltVT.getSizeInBits() < 16) { // Use 16-bit registers for small stores as it's the // smallest general purpose register size supported by NVPTX. - theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal); + StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal); } // Record the value to store. - StoreOperands.push_back(theVal); + StoreOperands.push_back(StVal); if (VectorInfo[j] & PVF_LAST) { unsigned NumElts = StoreOperands.size() - 3; @@ -1660,18 +1590,26 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, StoreOperands.push_back(InFlag); + // Adjust type of the store op if we've extended the scalar + // return value. + EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT; + Chain = DAG.getMemIntrinsicNode( Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands, - elemtype, MachinePointerInfo(), PartAlign, + TheStoreType, MachinePointerInfo(), PartAlign, MachineMemOperand::MOStore); InFlag = Chain.getValue(1); // Cleanup. StoreOperands.clear(); } + if (!IsByVal) + ++OIdx; } assert(StoreOperands.empty() && "Unfinished parameter store."); - ++paramCount; + if (!IsByVal && VTs.size() > 0) + --OIdx; + ++ParamCount; } GlobalAddressSDNode *Func = dyn_cast(Callee.getNode()); @@ -1778,7 +1716,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, CallArgBeginOps); InFlag = Chain.getValue(1); - for (unsigned i = 0, e = paramCount; i != e; ++i) { + for (unsigned i = 0, e = ParamCount; i != e; ++i) { unsigned opcode; if (i == (e - 1)) opcode = NVPTXISD::LastCallArg; -- 2.7.4