[ARM] Add custom strict fp conversion lowering when non-strict is custom
authorJohn Brawn <john.brawn@arm.com>
Thu, 5 Dec 2019 12:44:41 +0000 (12:44 +0000)
committerJohn Brawn <john.brawn@arm.com>
Fri, 13 Dec 2019 13:00:00 +0000 (13:00 +0000)
We have custom lowering for operations converting to/from floating-point types
when we don't have hardware support for those types, and this doesn't interact
well with the target-independent legalization of the strict versions of these
operations. Fix this by adding similar custom lowering of the strict versions.

This fixes the last of the assertion failures in the CodeGen/ARM/fp-intrinsics
test, with the remaining failures due to poor instruction selection.

Differential Revision: https://reviews.llvm.org/D71127

llvm/lib/Target/ARM/ARMISelLowering.cpp

index 0151440..0b4d39e 100644 (file)
@@ -977,19 +977,26 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::FP_TO_SINT, MVT::f64, Custom);
     setOperationAction(ISD::FP_TO_UINT, MVT::f64, Custom);
     setOperationAction(ISD::FP_ROUND,   MVT::f32, Custom);
+    setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i32, Custom);
+    setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::i32, Custom);
+    setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::f64, Custom);
+    setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::f64, Custom);
+    setOperationAction(ISD::STRICT_FP_ROUND,   MVT::f32, Custom);
   }
 
   if (!Subtarget->hasFP64() || !Subtarget->hasFPARMv8Base()) {
     setOperationAction(ISD::FP_EXTEND,  MVT::f64, Custom);
-    if (Subtarget->hasFullFP16())
+    setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Custom);
+    if (Subtarget->hasFullFP16()) {
       setOperationAction(ISD::FP_ROUND,  MVT::f16, Custom);
+      setOperationAction(ISD::STRICT_FP_ROUND, MVT::f16, Custom);
+    }
   }
 
-  if (!Subtarget->hasFP16())
+  if (!Subtarget->hasFP16()) {
     setOperationAction(ISD::FP_EXTEND,  MVT::f32, Custom);
-
-  if (!Subtarget->hasFP64())
-    setOperationAction(ISD::FP_ROUND,  MVT::f32, Custom);
+    setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Custom);
+  }
 
   computeRegisterProperties(Subtarget->getRegisterInfo());
 
@@ -5375,17 +5382,31 @@ SDValue ARMTargetLowering::LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
   if (VT.isVector())
     return LowerVectorFP_TO_INT(Op, DAG);
-  if (isUnsupportedFloatingType(Op.getOperand(0).getValueType())) {
+
+  bool IsStrict = Op->isStrictFPOpcode();
+  SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
+
+  if (isUnsupportedFloatingType(SrcVal.getValueType())) {
     RTLIB::Libcall LC;
-    if (Op.getOpcode() == ISD::FP_TO_SINT)
-      LC = RTLIB::getFPTOSINT(Op.getOperand(0).getValueType(),
+    if (Op.getOpcode() == ISD::FP_TO_SINT ||
+        Op.getOpcode() == ISD::STRICT_FP_TO_SINT)
+      LC = RTLIB::getFPTOSINT(SrcVal.getValueType(),
                               Op.getValueType());
     else
-      LC = RTLIB::getFPTOUINT(Op.getOperand(0).getValueType(),
+      LC = RTLIB::getFPTOUINT(SrcVal.getValueType(),
                               Op.getValueType());
+    SDLoc Loc(Op);
     MakeLibCallOptions CallOptions;
-    return makeLibCall(DAG, LC, Op.getValueType(), Op.getOperand(0),
-                       CallOptions, SDLoc(Op)).first;
+    SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
+    SDValue Result;
+    std::tie(Result, Chain) = makeLibCall(DAG, LC, Op.getValueType(), SrcVal,
+                                          CallOptions, Loc, Chain);
+    return IsStrict ? DAG.getMergeValues({Result, Chain}, Loc) : Result;
+  }
+
+  // FIXME: Remove this when we have strict fp instruction selection patterns
+  if (IsStrict) {
+    DAG.mutateStrictFPToFP(Op.getNode());
   }
 
   return Op;
@@ -9218,6 +9239,8 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::PREFETCH:      return LowerPREFETCH(Op, DAG, Subtarget);
   case ISD::SINT_TO_FP:
   case ISD::UINT_TO_FP:    return LowerINT_TO_FP(Op, DAG);
+  case ISD::STRICT_FP_TO_SINT:
+  case ISD::STRICT_FP_TO_UINT:
   case ISD::FP_TO_SINT:
   case ISD::FP_TO_UINT:    return LowerFP_TO_INT(Op, DAG);
   case ISD::FCOPYSIGN:     return LowerFCOPYSIGN(Op, DAG);
@@ -9286,7 +9309,9 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     if (Subtarget->isTargetWindows())
       return LowerDYNAMIC_STACKALLOC(Op, DAG);
     llvm_unreachable("Don't know how to custom lower this!");
+  case ISD::STRICT_FP_ROUND:
   case ISD::FP_ROUND: return LowerFP_ROUND(Op, DAG);
+  case ISD::STRICT_FP_EXTEND:
   case ISD::FP_EXTEND: return LowerFP_EXTEND(Op, DAG);
   case ARMISD::WIN__DBZCHK: return SDValue();
   }
@@ -16276,7 +16301,8 @@ ARMTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const
 }
 
 SDValue ARMTargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
-  SDValue SrcVal = Op.getOperand(0);
+  bool IsStrict = Op->isStrictFPOpcode();
+  SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
   const unsigned DstSz = Op.getValueType().getSizeInBits();
   const unsigned SrcSz = SrcVal.getValueType().getSizeInBits();
   assert(DstSz > SrcSz && DstSz <= 64 && SrcSz >= 16 &&
@@ -16296,34 +16322,35 @@ SDValue ARMTargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
   SDLoc Loc(Op);
   RTLIB::Libcall LC;
   MakeLibCallOptions CallOptions;
-  if (SrcSz == 16) {
-    // Instruction from 16 -> 32
-    if (Subtarget->hasFP16())
-      SrcVal = DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, SrcVal);
-    // Lib call from 16 -> 32
-    else {
-      LC = RTLIB::getFPEXT(MVT::f16, MVT::f32);
+  SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
+  for (unsigned Sz = SrcSz; Sz <= 32 && Sz < DstSz; Sz *= 2) {
+    bool Supported = (Sz == 16 ? Subtarget->hasFP16() : Subtarget->hasFP64());
+    MVT SrcVT = (Sz == 16 ? MVT::f16 : MVT::f32);
+    MVT DstVT = (Sz == 16 ? MVT::f32 : MVT::f64);
+    if (Supported) {
+      if (IsStrict) {
+        SrcVal = DAG.getNode(ISD::STRICT_FP_EXTEND, Loc,
+                             {DstVT, MVT::Other}, {Chain, SrcVal});
+        Chain = SrcVal.getValue(1);
+      } else {
+        SrcVal = DAG.getNode(ISD::FP_EXTEND, Loc, DstVT, SrcVal);
+      }
+    } else {
+      LC = RTLIB::getFPEXT(SrcVT, DstVT);
       assert(LC != RTLIB::UNKNOWN_LIBCALL &&
              "Unexpected type for custom-lowering FP_EXTEND");
-      SrcVal =
-        makeLibCall(DAG, LC, MVT::f32, SrcVal, CallOptions, Loc).first;
+      std::tie(SrcVal, Chain) = makeLibCall(DAG, LC, DstVT, SrcVal, CallOptions,
+                                            Loc, Chain);
     }
   }
 
-  if (DstSz != 64)
-    return SrcVal;
-  // For sure now SrcVal is 32 bits
-  if (Subtarget->hasFP64()) // Instruction from 32 -> 64
-    return DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f64, SrcVal);
-
-  LC = RTLIB::getFPEXT(MVT::f32, MVT::f64);
-  assert(LC != RTLIB::UNKNOWN_LIBCALL &&
-         "Unexpected type for custom-lowering FP_EXTEND");
-  return makeLibCall(DAG, LC, MVT::f64, SrcVal, CallOptions, Loc).first;
+  return IsStrict ? DAG.getMergeValues({SrcVal, Chain}, Loc) : SrcVal;
 }
 
 SDValue ARMTargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
-  SDValue SrcVal = Op.getOperand(0);
+  bool IsStrict = Op->isStrictFPOpcode();
+
+  SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
   EVT SrcVT = SrcVal.getValueType();
   EVT DstVT = Op.getValueType();
   const unsigned DstSz = Op.getValueType().getSizeInBits();
@@ -16346,7 +16373,11 @@ SDValue ARMTargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
   assert(LC != RTLIB::UNKNOWN_LIBCALL &&
          "Unexpected type for custom-lowering FP_ROUND");
   MakeLibCallOptions CallOptions;
-  return makeLibCall(DAG, LC, DstVT, SrcVal, CallOptions, Loc).first;
+  SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
+  SDValue Result;
+  std::tie(Result, Chain) = makeLibCall(DAG, LC, DstVT, SrcVal, CallOptions,
+                                        Loc, Chain);
+  return IsStrict ? DAG.getMergeValues({Result, Chain}, Loc) : Result;
 }
 
 void ARMTargetLowering::lowerABS(SDNode *N, SmallVectorImpl<SDValue> &Results,