R600: Fix LowerSDIV24
authorMatt Arsenault <Matthew.Arsenault@amd.com>
Thu, 24 Jul 2014 06:59:20 +0000 (06:59 +0000)
committerMatt Arsenault <Matthew.Arsenault@amd.com>
Thu, 24 Jul 2014 06:59:20 +0000 (06:59 +0000)
Use ComputeNumSignBits instead of checking for i8 / i16 which only
worked when AMDIL was lying about having legal i8 / i16.

If an integer is known to fit in 24-bits, we can
do division faster with float ops.

llvm-svn: 213843

llvm/lib/Target/R600/AMDGPUISelLowering.cpp
llvm/test/CodeGen/R600/sdiv24.ll [new file with mode: 0644]

index ffd6357..1eccaaf 100644 (file)
@@ -250,7 +250,7 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(TargetMachine &TM) :
   const MVT ScalarIntVTs[] = { MVT::i32, MVT::i64 };
   for (MVT VT : ScalarIntVTs) {
     setOperationAction(ISD::SREM, VT, Expand);
-    setOperationAction(ISD::SDIV, VT, Expand);
+    setOperationAction(ISD::SDIV, VT, Custom);
 
     // GPU does not have divrem function for signed or unsigned.
     setOperationAction(ISD::SDIVREM, VT, Custom);
@@ -1272,85 +1272,83 @@ SDValue AMDGPUTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
   return SDValue();
 }
 
+// This is a shortcut for integer division because we have fast i32<->f32
+// conversions, and fast f32 reciprocal instructions. The fractional part of a
+// float is enough to accurately represent up to a 24-bit integer.
 SDValue AMDGPUTargetLowering::LowerSDIV24(SDValue Op, SelectionDAG &DAG) const {
   SDLoc DL(Op);
-  EVT OVT = Op.getValueType();
+  EVT VT = Op.getValueType();
   SDValue LHS = Op.getOperand(0);
   SDValue RHS = Op.getOperand(1);
-  MVT INTTY;
-  MVT FLTTY;
-  if (!OVT.isVector()) {
-    INTTY = MVT::i32;
-    FLTTY = MVT::f32;
-  } else if (OVT.getVectorNumElements() == 2) {
-    INTTY = MVT::v2i32;
-    FLTTY = MVT::v2f32;
-  } else if (OVT.getVectorNumElements() == 4) {
-    INTTY = MVT::v4i32;
-    FLTTY = MVT::v4f32;
+  MVT IntVT = MVT::i32;
+  MVT FltVT = MVT::f32;
+
+  if (VT.isVector()) {
+    unsigned NElts = VT.getVectorNumElements();
+    IntVT = MVT::getVectorVT(MVT::i32, NElts);
+    FltVT = MVT::getVectorVT(MVT::f32, NElts);
   }
-  unsigned bitsize = OVT.getScalarType().getSizeInBits();
+
+  unsigned BitSize = VT.getScalarType().getSizeInBits();
+
   // char|short jq = ia ^ ib;
-  SDValue jq = DAG.getNode(ISD::XOR, DL, OVT, LHS, RHS);
+  SDValue jq = DAG.getNode(ISD::XOR, DL, VT, LHS, RHS);
 
   // jq = jq >> (bitsize - 2)
-  jq = DAG.getNode(ISD::SRA, DL, OVT, jq, DAG.getConstant(bitsize - 2, OVT));
+  jq = DAG.getNode(ISD::SRA, DL, VT, jq, DAG.getConstant(BitSize - 2, VT));
 
   // jq = jq | 0x1
-  jq = DAG.getNode(ISD::OR, DL, OVT, jq, DAG.getConstant(1, OVT));
+  jq = DAG.getNode(ISD::OR, DL, VT, jq, DAG.getConstant(1, VT));
 
   // jq = (int)jq
-  jq = DAG.getSExtOrTrunc(jq, DL, INTTY);
+  jq = DAG.getSExtOrTrunc(jq, DL, IntVT);
 
   // int ia = (int)LHS;
-  SDValue ia = DAG.getSExtOrTrunc(LHS, DL, INTTY);
+  SDValue ia = DAG.getSExtOrTrunc(LHS, DL, IntVT);
 
   // int ib, (int)RHS;
-  SDValue ib = DAG.getSExtOrTrunc(RHS, DL, INTTY);
+  SDValue ib = DAG.getSExtOrTrunc(RHS, DL, IntVT);
 
   // float fa = (float)ia;
-  SDValue fa = DAG.getNode(ISD::SINT_TO_FP, DL, FLTTY, ia);
+  SDValue fa = DAG.getNode(ISD::SINT_TO_FP, DL, FltVT, ia);
 
   // float fb = (float)ib;
-  SDValue fb = DAG.getNode(ISD::SINT_TO_FP, DL, FLTTY, ib);
+  SDValue fb = DAG.getNode(ISD::SINT_TO_FP, DL, FltVT, ib);
 
   // float fq = native_divide(fa, fb);
-  SDValue fq = DAG.getNode(ISD::FMUL, DL, FLTTY,
-                           fa, DAG.getNode(AMDGPUISD::RCP, DL, FLTTY, fb));
+  SDValue fq = DAG.getNode(ISD::FMUL, DL, FltVT,
+                           fa, DAG.getNode(AMDGPUISD::RCP, DL, FltVT, fb));
 
   // fq = trunc(fq);
-  fq = DAG.getNode(ISD::FTRUNC, DL, FLTTY, fq);
+  fq = DAG.getNode(ISD::FTRUNC, DL, FltVT, fq);
 
   // float fqneg = -fq;
-  SDValue fqneg = DAG.getNode(ISD::FNEG, DL, FLTTY, fq);
+  SDValue fqneg = DAG.getNode(ISD::FNEG, DL, FltVT, fq);
 
   // float fr = mad(fqneg, fb, fa);
-  SDValue fr = DAG.getNode(ISD::FADD, DL, FLTTY,
-      DAG.getNode(ISD::MUL, DL, FLTTY, fqneg, fb), fa);
+  SDValue fr = DAG.getNode(ISD::FADD, DL, FltVT,
+                           DAG.getNode(ISD::FMUL, DL, FltVT, fqneg, fb), fa);
 
   // int iq = (int)fq;
-  SDValue iq = DAG.getNode(ISD::FP_TO_SINT, DL, INTTY, fq);
+  SDValue iq = DAG.getNode(ISD::FP_TO_SINT, DL, IntVT, fq);
 
   // fr = fabs(fr);
-  fr = DAG.getNode(ISD::FABS, DL, FLTTY, fr);
+  fr = DAG.getNode(ISD::FABS, DL, FltVT, fr);
 
   // fb = fabs(fb);
-  fb = DAG.getNode(ISD::FABS, DL, FLTTY, fb);
+  fb = DAG.getNode(ISD::FABS, DL, FltVT, fb);
+
+  EVT SetCCVT = getSetCCResultType(*DAG.getContext(), VT);
 
   // int cv = fr >= fb;
-  SDValue cv;
-  if (INTTY == MVT::i32) {
-    cv = DAG.getSetCC(DL, INTTY, fr, fb, ISD::SETOGE);
-  } else {
-    cv = DAG.getSetCC(DL, INTTY, fr, fb, ISD::SETOGE);
-  }
+  SDValue cv = DAG.getSetCC(DL, SetCCVT, fr, fb, ISD::SETOGE);
+
   // jq = (cv ? jq : 0);
-  jq = DAG.getNode(ISD::SELECT, DL, OVT, cv, jq,
-      DAG.getConstant(0, OVT));
+  jq = DAG.getNode(ISD::SELECT, DL, VT, cv, jq, DAG.getConstant(0, VT));
+
   // dst = iq + jq;
-  iq = DAG.getSExtOrTrunc(iq, DL, OVT);
-  iq = DAG.getNode(ISD::ADD, DL, OVT, iq, jq);
-  return iq;
+  iq = DAG.getSExtOrTrunc(iq, DL, VT);
+  return DAG.getNode(ISD::ADD, DL, VT, iq, jq);
 }
 
 SDValue AMDGPUTargetLowering::LowerSDIV32(SDValue Op, SelectionDAG &DAG) const {
@@ -1425,19 +1423,20 @@ SDValue AMDGPUTargetLowering::LowerSDIV64(SDValue Op, SelectionDAG &DAG) const {
 SDValue AMDGPUTargetLowering::LowerSDIV(SDValue Op, SelectionDAG &DAG) const {
   EVT OVT = Op.getValueType().getScalarType();
 
-  if (OVT == MVT::i64)
-    return LowerSDIV64(Op, DAG);
+  if (OVT == MVT::i32) {
+    if (DAG.ComputeNumSignBits(Op.getOperand(0)) > 8 &&
+        DAG.ComputeNumSignBits(Op.getOperand(1)) > 8) {
+      // TODO: We technically could do this for i64, but shouldn't that just be
+      // handled by something generally reducing 64-bit division on 32-bit
+      // values to 32-bit?
+      return LowerSDIV24(Op, DAG);
+    }
 
-  if (OVT.getScalarType() == MVT::i32)
     return LowerSDIV32(Op, DAG);
-
-  if (OVT == MVT::i16 || OVT == MVT::i8) {
-    // FIXME: We should be checking for the masked bits. This isn't reached
-    // because i8 and i16 are not legal types.
-    return LowerSDIV24(Op, DAG);
   }
 
-  return SDValue(Op.getNode(), 0);
+  assert(OVT == MVT::i64);
+  return LowerSDIV64(Op, DAG);
 }
 
 SDValue AMDGPUTargetLowering::LowerSREM32(SDValue Op, SelectionDAG &DAG) const {
diff --git a/llvm/test/CodeGen/R600/sdiv24.ll b/llvm/test/CodeGen/R600/sdiv24.ll
new file mode 100644 (file)
index 0000000..84c9ecb
--- /dev/null
@@ -0,0 +1,120 @@
+; RUN: llc -march=r600 -mcpu=SI < %s | FileCheck -check-prefix=SI -check-prefix=FUNC %s
+; RUN: llc -march=r600 -mcpu=redwood < %s | FileCheck -check-prefix=EG -check-prefix=FUNC %s
+
+; FUNC-LABEL: @sdiv24_i8
+; SI: V_CVT_F32_I32
+; SI: V_CVT_F32_I32
+; SI: V_RCP_F32
+; SI: V_CVT_I32_F32
+
+; EG: INT_TO_FLT
+; EG-DAG: INT_TO_FLT
+; EG-DAG: RECIP_IEEE
+; EG: FLT_TO_INT
+define void @sdiv24_i8(i8 addrspace(1)* %out, i8 addrspace(1)* %in) {
+  %den_ptr = getelementptr i8 addrspace(1)* %in, i8 1
+  %num = load i8 addrspace(1) * %in
+  %den = load i8 addrspace(1) * %den_ptr
+  %result = sdiv i8 %num, %den
+  store i8 %result, i8 addrspace(1)* %out
+  ret void
+}
+
+; FUNC-LABEL: @sdiv24_i16
+; SI: V_CVT_F32_I32
+; SI: V_CVT_F32_I32
+; SI: V_RCP_F32
+; SI: V_CVT_I32_F32
+
+; EG: INT_TO_FLT
+; EG-DAG: INT_TO_FLT
+; EG-DAG: RECIP_IEEE
+; EG: FLT_TO_INT
+define void @sdiv24_i16(i16 addrspace(1)* %out, i16 addrspace(1)* %in) {
+  %den_ptr = getelementptr i16 addrspace(1)* %in, i16 1
+  %num = load i16 addrspace(1) * %in, align 2
+  %den = load i16 addrspace(1) * %den_ptr, align 2
+  %result = sdiv i16 %num, %den
+  store i16 %result, i16 addrspace(1)* %out, align 2
+  ret void
+}
+
+; FUNC-LABEL: @sdiv24_i32
+; SI: V_CVT_F32_I32
+; SI: V_CVT_F32_I32
+; SI: V_RCP_F32
+; SI: V_CVT_I32_F32
+
+; EG: INT_TO_FLT
+; EG-DAG: INT_TO_FLT
+; EG-DAG: RECIP_IEEE
+; EG: FLT_TO_INT
+define void @sdiv24_i32(i32 addrspace(1)* %out, i32 addrspace(1)* %in) {
+  %den_ptr = getelementptr i32 addrspace(1)* %in, i32 1
+  %num = load i32 addrspace(1) * %in, align 4
+  %den = load i32 addrspace(1) * %den_ptr, align 4
+  %num.i24.0 = shl i32 %num, 8
+  %den.i24.0 = shl i32 %den, 8
+  %num.i24 = ashr i32 %num.i24.0, 8
+  %den.i24 = ashr i32 %den.i24.0, 8
+  %result = sdiv i32 %num.i24, %den.i24
+  store i32 %result, i32 addrspace(1)* %out, align 4
+  ret void
+}
+
+; FUNC-LABEL: @sdiv25_i32
+; SI-NOT: V_CVT_F32_I32
+; SI-NOT: V_RCP_F32
+
+; EG-NOT: INT_TO_FLT
+; EG-NOT: RECIP_IEEE
+define void @sdiv25_i32(i32 addrspace(1)* %out, i32 addrspace(1)* %in) {
+  %den_ptr = getelementptr i32 addrspace(1)* %in, i32 1
+  %num = load i32 addrspace(1) * %in, align 4
+  %den = load i32 addrspace(1) * %den_ptr, align 4
+  %num.i24.0 = shl i32 %num, 7
+  %den.i24.0 = shl i32 %den, 7
+  %num.i24 = ashr i32 %num.i24.0, 7
+  %den.i24 = ashr i32 %den.i24.0, 7
+  %result = sdiv i32 %num.i24, %den.i24
+  store i32 %result, i32 addrspace(1)* %out, align 4
+  ret void
+}
+
+; FUNC-LABEL: @test_no_sdiv24_i32_1
+; SI-NOT: V_CVT_F32_I32
+; SI-NOT: V_RCP_F32
+
+; EG-NOT: INT_TO_FLT
+; EG-NOT: RECIP_IEEE
+define void @test_no_sdiv24_i32_1(i32 addrspace(1)* %out, i32 addrspace(1)* %in) {
+  %den_ptr = getelementptr i32 addrspace(1)* %in, i32 1
+  %num = load i32 addrspace(1) * %in, align 4
+  %den = load i32 addrspace(1) * %den_ptr, align 4
+  %num.i24.0 = shl i32 %num, 8
+  %den.i24.0 = shl i32 %den, 7
+  %num.i24 = ashr i32 %num.i24.0, 8
+  %den.i24 = ashr i32 %den.i24.0, 7
+  %result = sdiv i32 %num.i24, %den.i24
+  store i32 %result, i32 addrspace(1)* %out, align 4
+  ret void
+}
+
+; FUNC-LABEL: @test_no_sdiv24_i32_2
+; SI-NOT: V_CVT_F32_I32
+; SI-NOT: V_RCP_F32
+
+; EG-NOT: INT_TO_FLT
+; EG-NOT: RECIP_IEEE
+define void @test_no_sdiv24_i32_2(i32 addrspace(1)* %out, i32 addrspace(1)* %in) {
+  %den_ptr = getelementptr i32 addrspace(1)* %in, i32 1
+  %num = load i32 addrspace(1) * %in, align 4
+  %den = load i32 addrspace(1) * %den_ptr, align 4
+  %num.i24.0 = shl i32 %num, 7
+  %den.i24.0 = shl i32 %den, 8
+  %num.i24 = ashr i32 %num.i24.0, 7
+  %den.i24 = ashr i32 %den.i24.0, 8
+  %result = sdiv i32 %num.i24, %den.i24
+  store i32 %result, i32 addrspace(1)* %out, align 4
+  ret void
+}