[RISCV] Initial support for function calls
authorAlex Bradbury <asb@lowrisc.org>
Wed, 8 Nov 2017 13:41:21 +0000 (13:41 +0000)
committerAlex Bradbury <asb@lowrisc.org>
Wed, 8 Nov 2017 13:41:21 +0000 (13:41 +0000)
Note that this is just enough for simple function call examples to generate
working code. Support for varargs etc follows in future patches.

Differential Revision: https://reviews.llvm.org/D29936

llvm-svn: 317691

llvm/lib/Target/RISCV/RISCVFrameLowering.h
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
llvm/lib/Target/RISCV/RISCVInstrInfo.td
llvm/lib/Target/RISCV/RISCVMCInstLower.cpp
llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
llvm/lib/Target/RISCV/RISCVRegisterInfo.h
llvm/test/CodeGen/RISCV/calls.ll [new file with mode: 0644]

index 71f8586..0b2c7a4 100644 (file)
@@ -30,6 +30,12 @@ public:
   void emitEpilogue(MachineFunction &MF, MachineBasicBlock &MBB) const override;
 
   bool hasFP(const MachineFunction &MF) const override;
+
+  MachineBasicBlock::iterator
+  eliminateCallFramePseudoInstr(MachineFunction &MF, MachineBasicBlock &MBB,
+                                MachineBasicBlock::iterator MI) const override {
+    return MBB.erase(MI);
+  }
 };
 }
 #endif
index 9f19cda..98f7aa1 100644 (file)
@@ -108,6 +108,7 @@ SDValue RISCVTargetLowering::LowerFormalArguments(
   default:
     report_fatal_error("Unsupported calling convention");
   case CallingConv::C:
+  case CallingConv::Fast:
     break;
   }
 
@@ -144,6 +145,135 @@ SDValue RISCVTargetLowering::LowerFormalArguments(
   return Chain;
 }
 
+// Lower a call to a callseq_start + CALL + callseq_end chain, and add input
+// and output parameter nodes.
+SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
+                                       SmallVectorImpl<SDValue> &InVals) const {
+  SelectionDAG &DAG = CLI.DAG;
+  SDLoc &DL = CLI.DL;
+  SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
+  SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
+  SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
+  SDValue Chain = CLI.Chain;
+  SDValue Callee = CLI.Callee;
+  CLI.IsTailCall = false;
+  CallingConv::ID CallConv = CLI.CallConv;
+  bool IsVarArg = CLI.IsVarArg;
+  EVT PtrVT = getPointerTy(DAG.getDataLayout());
+
+  if (IsVarArg) {
+    report_fatal_error("LowerCall with varargs not implemented");
+  }
+
+  MachineFunction &MF = DAG.getMachineFunction();
+
+  // Analyze the operands of the call, assigning locations to each operand.
+  SmallVector<CCValAssign, 16> ArgLocs;
+  CCState ArgCCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext());
+  ArgCCInfo.AnalyzeCallOperands(Outs, CC_RISCV32);
+
+  // Get a count of how many bytes are to be pushed on the stack.
+  unsigned NumBytes = ArgCCInfo.getNextStackOffset();
+
+  for (auto &Arg : Outs) {
+    if (!Arg.Flags.isByVal())
+      continue;
+    report_fatal_error("Passing arguments byval not yet implemented");
+  }
+
+  Chain = DAG.getCALLSEQ_START(Chain, NumBytes, 0, CLI.DL);
+
+  // Copy argument values to their designated locations.
+  SmallVector<std::pair<unsigned, SDValue>, 8> RegsToPass;
+  SDValue StackPtr;
+  for (unsigned I = 0, E = ArgLocs.size(); I != E; ++I) {
+    CCValAssign &VA = ArgLocs[I];
+    SDValue ArgValue = OutVals[I];
+
+    // Promote the value if needed.
+    // For now, only handle fully promoted arguments.
+    switch (VA.getLocInfo()) {
+    case CCValAssign::Full:
+      break;
+    default:
+      llvm_unreachable("Unknown loc info!");
+    }
+
+    if (VA.isRegLoc()) {
+      // Queue up the argument copies and emit them at the end.
+      RegsToPass.push_back(std::make_pair(VA.getLocReg(), ArgValue));
+    } else {
+      assert(VA.isMemLoc() && "Argument not register or memory");
+      report_fatal_error("Passing arguments via the stack not yet implemented");
+    }
+  }
+
+  SDValue Glue;
+
+  // Build a sequence of copy-to-reg nodes, chained and glued together.
+  for (auto &Reg : RegsToPass) {
+    Chain = DAG.getCopyToReg(Chain, DL, Reg.first, Reg.second, Glue);
+    Glue = Chain.getValue(1);
+  }
+
+  if (isa<GlobalAddressSDNode>(Callee)) {
+    Callee = lowerGlobalAddress(Callee, DAG);
+  } else if (isa<ExternalSymbolSDNode>(Callee)) {
+    report_fatal_error(
+        "lowerExternalSymbol, needed for lowerCall, not yet handled");
+  }
+
+  // The first call operand is the chain and the second is the target address.
+  SmallVector<SDValue, 8> Ops;
+  Ops.push_back(Chain);
+  Ops.push_back(Callee);
+
+  // Add argument registers to the end of the list so that they are
+  // known live into the call.
+  for (auto &Reg : RegsToPass)
+    Ops.push_back(DAG.getRegister(Reg.first, Reg.second.getValueType()));
+
+  // Add a register mask operand representing the call-preserved registers.
+  const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
+  const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv);
+  assert(Mask && "Missing call preserved mask for calling convention");
+  Ops.push_back(DAG.getRegisterMask(Mask));
+
+  // Glue the call to the argument copies, if any.
+  if (Glue.getNode())
+    Ops.push_back(Glue);
+
+  // Emit the call.
+  SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
+  Chain = DAG.getNode(RISCVISD::CALL, DL, NodeTys, Ops);
+  Glue = Chain.getValue(1);
+
+  // Mark the end of the call, which is glued to the call itself.
+  Chain = DAG.getCALLSEQ_END(Chain,
+                             DAG.getConstant(NumBytes, DL, PtrVT, true),
+                             DAG.getConstant(0, DL, PtrVT, true),
+                             Glue, DL);
+  Glue = Chain.getValue(1);
+
+  // Assign locations to each value returned by this call.
+  SmallVector<CCValAssign, 16> RVLocs;
+  CCState RetCCInfo(CallConv, IsVarArg, MF, RVLocs, *DAG.getContext());
+  RetCCInfo.AnalyzeCallResult(Ins, RetCC_RISCV32);
+
+  // Copy all of the result registers out of their specified physreg.
+  for (auto &VA : RVLocs) {
+    // Copy the value out, gluing the copy to the end of the call sequence.
+    SDValue RetValue = DAG.getCopyFromReg(Chain, DL, VA.getLocReg(),
+                                          VA.getLocVT(), Glue);
+    Chain = RetValue.getValue(1);
+    Glue = RetValue.getValue(2);
+
+    InVals.push_back(Chain.getValue(0));
+  }
+
+  return Chain;
+}
+
 SDValue
 RISCVTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                  bool IsVarArg,
@@ -194,6 +324,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
     break;
   case RISCVISD::RET_FLAG:
     return "RISCVISD::RET_FLAG";
+  case RISCVISD::CALL:
+    return "RISCVISD::CALL";
   }
   return nullptr;
 }
index 2bd5717..dfb4824 100644 (file)
@@ -24,7 +24,8 @@ class RISCVSubtarget;
 namespace RISCVISD {
 enum NodeType : unsigned {
   FIRST_NUMBER = ISD::BUILTIN_OP_END,
-  RET_FLAG
+  RET_FLAG,
+  CALL
 };
 }
 
@@ -52,6 +53,8 @@ private:
                       const SmallVectorImpl<ISD::OutputArg> &Outs,
                       const SmallVectorImpl<SDValue> &OutVals, const SDLoc &DL,
                       SelectionDAG &DAG) const override;
+  SDValue LowerCall(TargetLowering::CallLoweringInfo &CLI,
+                    SmallVectorImpl<SDValue> &InVals) const override;
   bool shouldConvertConstantLoadToIntImm(const APInt &Imm,
                                          Type *Ty) const override {
     return true;
index 63178bb..5b4f4fc 100644 (file)
@@ -28,7 +28,8 @@
 
 using namespace llvm;
 
-RISCVInstrInfo::RISCVInstrInfo() : RISCVGenInstrInfo() {}
+RISCVInstrInfo::RISCVInstrInfo()
+    : RISCVGenInstrInfo(RISCV::ADJCALLSTACKDOWN, RISCV::ADJCALLSTACKUP) {}
 
 void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
                                  MachineBasicBlock::iterator MBBI,
index 2c21ad3..23f218f 100644 (file)
@@ -17,8 +17,22 @@ include "RISCVInstrFormats.td"
 // RISC-V specific DAG Nodes.
 //===----------------------------------------------------------------------===//
 
-def RetFlag : SDNode<"RISCVISD::RET_FLAG", SDTNone,
-                     [SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>;
+def SDT_RISCVCall         : SDTypeProfile<0, -1, [SDTCisVT<0, XLenVT>]>;
+def SDT_RISCVCallSeqStart : SDCallSeqStart<[SDTCisVT<0, i32>,
+                                            SDTCisVT<1, i32>]>;
+def SDT_RISCVCallSeqEnd   : SDCallSeqEnd<[SDTCisVT<0, i32>,
+                                          SDTCisVT<1, i32>]>;
+
+
+def Call         : SDNode<"RISCVISD::CALL", SDT_RISCVCall,
+                          [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
+                           SDNPVariadic]>;
+def CallSeqStart : SDNode<"ISD::CALLSEQ_START", SDT_RISCVCallSeqStart,
+                          [SDNPHasChain, SDNPOutGlue]>;
+def CallSeqEnd   : SDNode<"ISD::CALLSEQ_END", SDT_RISCVCallSeqEnd,
+                          [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue]>;
+def RetFlag      : SDNode<"RISCVISD::RET_FLAG", SDTNone,
+                          [SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>;
 
 //===----------------------------------------------------------------------===//
 // Operand and SDNode transformation definitions.
@@ -340,6 +354,10 @@ let isBarrier = 1, isBranch = 1, isTerminator = 1 in
 def PseudoBR : Pseudo<(outs), (ins simm21_lsb0:$imm20), [(br bb:$imm20)]>,
                PseudoInstExpansion<(JAL X0, simm21_lsb0:$imm20)>;
 
+let isCall = 1, Defs=[X1] in
+def PseudoCALL : Pseudo<(outs), (ins GPR:$rs1), [(Call GPR:$rs1)]>,
+                 PseudoInstExpansion<(JALR X1, GPR:$rs1, 0)>;
+
 let isBarrier = 1, isReturn = 1, isTerminator = 1 in
 def PseudoRET : Pseudo<(outs), (ins), [(RetFlag)]>,
                 PseudoInstExpansion<(JALR X0, X1, 0)>;
@@ -371,3 +389,13 @@ multiclass StPat<PatFrag StoreOp, RVInst Inst> {
 defm : StPat<truncstorei8, SB>;
 defm : StPat<truncstorei16, SH>;
 defm : StPat<store, SW>;
+
+/// Other pseudo-instructions
+
+// Pessimistically assume the stack pointer will be clobbered
+let Defs = [X2], Uses = [X2] in {
+def ADJCALLSTACKDOWN : Pseudo<(outs), (ins i32imm:$amt1, i32imm:$amt2),
+                              [(CallSeqStart timm:$amt1, timm:$amt2)]>;
+def ADJCALLSTACKUP   : Pseudo<(outs), (ins i32imm:$amt1, i32imm:$amt2),
+                              [(CallSeqEnd timm:$amt1, timm:$amt2)]>;
+} // Defs = [X2], Uses = [X2]
index a354a34..ef0051e 100644 (file)
@@ -68,6 +68,9 @@ bool llvm::LowerRISCVMachineOperandToMCOperand(const MachineOperand &MO,
       return false;
     MCOp = MCOperand::createReg(MO.getReg());
     break;
+  case MachineOperand::MO_RegisterMask:
+    // Regmasks are like implicit defs.
+    return false;
   case MachineOperand::MO_Immediate:
     MCOp = MCOperand::createImm(MO.getImm());
     break;
index 5d64de2..cd658d7 100644 (file)
@@ -88,3 +88,9 @@ void RISCVRegisterInfo::eliminateFrameIndex(MachineBasicBlock::iterator II,
 unsigned RISCVRegisterInfo::getFrameRegister(const MachineFunction &MF) const {
   return RISCV::X8;
 }
+
+const uint32_t *
+RISCVRegisterInfo::getCallPreservedMask(const MachineFunction & /*MF*/,
+                                        CallingConv::ID /*CC*/) const {
+  return CSR_RegMask;
+}
index a9b6f07..d9de9bf 100644 (file)
@@ -25,6 +25,9 @@ struct RISCVRegisterInfo : public RISCVGenRegisterInfo {
 
   RISCVRegisterInfo(unsigned HwMode);
 
+  const uint32_t *getCallPreservedMask(const MachineFunction &MF,
+                                       CallingConv::ID) const override;
+
   const MCPhysReg *getCalleeSavedRegs(const MachineFunction *MF) const override;
 
   BitVector getReservedRegs(const MachineFunction &MF) const override;
diff --git a/llvm/test/CodeGen/RISCV/calls.ll b/llvm/test/CodeGen/RISCV/calls.ll
new file mode 100644 (file)
index 0000000..8abe5e9
--- /dev/null
@@ -0,0 +1,83 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV32I %s
+
+declare i32 @external_function(i32)
+
+define i32 @test_call_external(i32 %a) nounwind {
+; RV32I-LABEL: test_call_external:
+; RV32I:       # BB#0:
+; RV32I-NEXT:    sw ra, 12(s0)
+; RV32I-NEXT:    lui a1, %hi(external_function)
+; RV32I-NEXT:    addi a1, a1, %lo(external_function)
+; RV32I-NEXT:    jalr ra, a1, 0
+; RV32I-NEXT:    lw ra, 12(s0)
+; RV32I-NEXT:    jalr zero, ra, 0
+  %1 = call i32 @external_function(i32 %a)
+  ret i32 %1
+}
+
+define i32 @defined_function(i32 %a) nounwind {
+; RV32I-LABEL: defined_function:
+; RV32I:       # BB#0:
+; RV32I-NEXT:    addi a0, a0, 1
+; RV32I-NEXT:    jalr zero, ra, 0
+  %1 = add i32 %a, 1
+  ret i32 %1
+}
+
+define i32 @test_call_defined(i32 %a) nounwind {
+; RV32I-LABEL: test_call_defined:
+; RV32I:       # BB#0:
+; RV32I-NEXT:    sw ra, 12(s0)
+; RV32I-NEXT:    lui a1, %hi(defined_function)
+; RV32I-NEXT:    addi a1, a1, %lo(defined_function)
+; RV32I-NEXT:    jalr ra, a1, 0
+; RV32I-NEXT:    lw ra, 12(s0)
+; RV32I-NEXT:    jalr zero, ra, 0
+  %1 = call i32 @defined_function(i32 %a) nounwind
+  ret i32 %1
+}
+
+define i32 @test_call_indirect(i32 (i32)* %a, i32 %b) nounwind {
+; RV32I-LABEL: test_call_indirect:
+; RV32I:       # BB#0:
+; RV32I-NEXT:    sw ra, 12(s0)
+; RV32I-NEXT:    addi a2, a0, 0
+; RV32I-NEXT:    addi a0, a1, 0
+; RV32I-NEXT:    jalr ra, a2, 0
+; RV32I-NEXT:    lw ra, 12(s0)
+; RV32I-NEXT:    jalr zero, ra, 0
+  %1 = call i32 %a(i32 %b)
+  ret i32 %1
+}
+
+; Ensure that calls to fastcc functions aren't rejected. Such calls may be
+; introduced when compiling with optimisation.
+
+define fastcc i32 @fastcc_function(i32 %a, i32 %b) nounwind {
+; RV32I-LABEL: fastcc_function:
+; RV32I:       # BB#0:
+; RV32I-NEXT:    add a0, a0, a1
+; RV32I-NEXT:    jalr zero, ra, 0
+ %1 = add i32 %a, %b
+ ret i32 %1
+}
+
+define i32 @test_call_fastcc(i32 %a, i32 %b) nounwind {
+; RV32I-LABEL: test_call_fastcc:
+; RV32I:       # BB#0:
+; RV32I-NEXT:    sw ra, 12(s0)
+; RV32I-NEXT:    sw s1, 8(s0)
+; RV32I-NEXT:    addi s1, a0, 0
+; RV32I-NEXT:    lui a0, %hi(fastcc_function)
+; RV32I-NEXT:    addi a2, a0, %lo(fastcc_function)
+; RV32I-NEXT:    addi a0, s1, 0
+; RV32I-NEXT:    jalr ra, a2, 0
+; RV32I-NEXT:    addi a0, s1, 0
+; RV32I-NEXT:    lw s1, 8(s0)
+; RV32I-NEXT:    lw ra, 12(s0)
+; RV32I-NEXT:    jalr zero, ra, 0
+  %1 = call fastcc i32 @fastcc_function(i32 %a, i32 %b)
+  ret i32 %a
+}