[SelectionDAG] Implement soft FP legalisation for bf16 FP_EXTEND and BF16_TO_FP
authorAlex Bradbury <asb@igalia.com>
Mon, 29 May 2023 09:32:28 +0000 (10:32 +0100)
committerAlex Bradbury <asb@igalia.com>
Mon, 29 May 2023 09:32:28 +0000 (10:32 +0100)
As discussed in D151436, it's safe to do this as a simple shift (as is
done in LegalizeDAG.cpp) rather than needing a libcall. The added test
cases for RISC-V previously just triggered an assertion.

Codegen for bfloat_to_double will be slightly improved by D151434.

Differential Revision: https://reviews.llvm.org/D151563

llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
llvm/test/CodeGen/RISCV/bfloat.ll [new file with mode: 0644]

index f1e80ce..29a1951 100644 (file)
@@ -107,6 +107,7 @@ void DAGTypeLegalizer::SoftenFloatResult(SDNode *N, unsigned ResNo) {
     case ISD::STRICT_FP_ROUND:
     case ISD::FP_ROUND:    R = SoftenFloatRes_FP_ROUND(N); break;
     case ISD::FP16_TO_FP:  R = SoftenFloatRes_FP16_TO_FP(N); break;
+    case ISD::BF16_TO_FP:  R = SoftenFloatRes_BF16_TO_FP(N); break;
     case ISD::STRICT_FPOW:
     case ISD::FPOW:        R = SoftenFloatRes_FPOW(N); break;
     case ISD::STRICT_FPOWI:
@@ -510,10 +511,12 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_FP_EXTEND(SDNode *N) {
       return BitConvertToInteger(Op);
   }
 
-  // There's only a libcall for f16 -> f32, so proceed in two stages. Also, it's
-  // entirely possible for both f16 and f32 to be legal, so use the fully
-  // hard-float FP_EXTEND rather than FP16_TO_FP.
-  if (Op.getValueType() == MVT::f16 && N->getValueType(0) != MVT::f32) {
+  // There's only a libcall for f16 -> f32 and shifting is only valid for bf16
+  // -> f32, so proceed in two stages. Also, it's entirely possible for both
+  // f16 and f32 to be legal, so use the fully hard-float FP_EXTEND rather
+  // than FP16_TO_FP.
+  if ((Op.getValueType() == MVT::f16 || Op.getValueType() == MVT::bf16) &&
+      N->getValueType(0) != MVT::f32) {
     if (IsStrict) {
       Op = DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(N),
                        { MVT::f32, MVT::Other }, { Chain, Op });
@@ -523,6 +526,9 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_FP_EXTEND(SDNode *N) {
     }
   }
 
+  if (Op.getValueType() == MVT::bf16)
+    return SoftenFloatRes_BF16_TO_FP(N);
+
   RTLIB::Libcall LC = RTLIB::getFPEXT(Op.getValueType(), N->getValueType(0));
   assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unsupported FP_EXTEND!");
   TargetLowering::MakeLibCallOptions CallOptions;
@@ -555,6 +561,21 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_FP16_TO_FP(SDNode *N) {
   return TLI.makeLibCall(DAG, LC, NVT, Res32, CallOptions, SDLoc(N)).first;
 }
 
+// FIXME: Should we just use 'normal' FP_EXTEND / FP_TRUNC instead of special
+// nodes?
+SDValue DAGTypeLegalizer::SoftenFloatRes_BF16_TO_FP(SDNode *N) {
+  assert(N->getValueType(0) == MVT::f32 &&
+         "Can only soften BF16_TO_FP with f32 result");
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), MVT::f32);
+  SDValue Op = N->getOperand(0);
+  SDLoc DL(N);
+  Op = DAG.getNode(ISD::ANY_EXTEND, DL, NVT,
+                   DAG.getNode(ISD::BITCAST, DL, MVT::i16, Op));
+  SDValue Res = DAG.getNode(ISD::SHL, DL, NVT, Op,
+                            DAG.getShiftAmountConstant(16, NVT, DL));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::SoftenFloatRes_FP_ROUND(SDNode *N) {
   bool IsStrict = N->isStrictFPOpcode();
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
index 09d47ca..e73b6b1 100644 (file)
@@ -560,6 +560,7 @@ private:
   SDValue SoftenFloatRes_FNEG(SDNode *N);
   SDValue SoftenFloatRes_FP_EXTEND(SDNode *N);
   SDValue SoftenFloatRes_FP16_TO_FP(SDNode *N);
+  SDValue SoftenFloatRes_BF16_TO_FP(SDNode *N);
   SDValue SoftenFloatRes_FP_ROUND(SDNode *N);
   SDValue SoftenFloatRes_FPOW(SDNode *N);
   SDValue SoftenFloatRes_FPOWI(SDNode *N);
diff --git a/llvm/test/CodeGen/RISCV/bfloat.ll b/llvm/test/CodeGen/RISCV/bfloat.ll
new file mode 100644 (file)
index 0000000..e7583a5
--- /dev/null
@@ -0,0 +1,116 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -verify-machineinstrs < %s | FileCheck %s -check-prefix=RV32I-ILP32
+; RUN: llc -mtriple=riscv64 -verify-machineinstrs < %s | FileCheck %s -check-prefix=RV64I-LP64
+
+; TODO: Enable codegen for hard float.
+
+define bfloat @float_to_bfloat(float %a) nounwind {
+; RV32I-ILP32-LABEL: float_to_bfloat:
+; RV32I-ILP32:       # %bb.0:
+; RV32I-ILP32-NEXT:    addi sp, sp, -16
+; RV32I-ILP32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32I-ILP32-NEXT:    call __truncsfbf2@plt
+; RV32I-ILP32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32I-ILP32-NEXT:    addi sp, sp, 16
+; RV32I-ILP32-NEXT:    ret
+;
+; RV64I-LP64-LABEL: float_to_bfloat:
+; RV64I-LP64:       # %bb.0:
+; RV64I-LP64-NEXT:    addi sp, sp, -16
+; RV64I-LP64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64I-LP64-NEXT:    call __truncsfbf2@plt
+; RV64I-LP64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64I-LP64-NEXT:    addi sp, sp, 16
+; RV64I-LP64-NEXT:    ret
+  %1 = fptrunc float %a to bfloat
+  ret bfloat %1
+}
+
+define bfloat @double_to_bfloat(double %a) nounwind {
+; RV32I-ILP32-LABEL: double_to_bfloat:
+; RV32I-ILP32:       # %bb.0:
+; RV32I-ILP32-NEXT:    addi sp, sp, -16
+; RV32I-ILP32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32I-ILP32-NEXT:    call __truncdfbf2@plt
+; RV32I-ILP32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32I-ILP32-NEXT:    addi sp, sp, 16
+; RV32I-ILP32-NEXT:    ret
+;
+; RV64I-LP64-LABEL: double_to_bfloat:
+; RV64I-LP64:       # %bb.0:
+; RV64I-LP64-NEXT:    addi sp, sp, -16
+; RV64I-LP64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64I-LP64-NEXT:    call __truncdfbf2@plt
+; RV64I-LP64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64I-LP64-NEXT:    addi sp, sp, 16
+; RV64I-LP64-NEXT:    ret
+  %1 = fptrunc double %a to bfloat
+  ret bfloat %1
+}
+
+define float @bfloat_to_float(bfloat %a) nounwind {
+; RV32I-ILP32-LABEL: bfloat_to_float:
+; RV32I-ILP32:       # %bb.0:
+; RV32I-ILP32-NEXT:    slli a0, a0, 16
+; RV32I-ILP32-NEXT:    ret
+;
+; RV64I-LP64-LABEL: bfloat_to_float:
+; RV64I-LP64:       # %bb.0:
+; RV64I-LP64-NEXT:    slliw a0, a0, 16
+; RV64I-LP64-NEXT:    ret
+  %1 = fpext bfloat %a to float
+  ret float %1
+}
+
+define double @bfloat_to_double(bfloat %a) nounwind {
+; RV32I-ILP32-LABEL: bfloat_to_double:
+; RV32I-ILP32:       # %bb.0:
+; RV32I-ILP32-NEXT:    addi sp, sp, -16
+; RV32I-ILP32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32I-ILP32-NEXT:    slli a0, a0, 16
+; RV32I-ILP32-NEXT:    call __extendsfdf2@plt
+; RV32I-ILP32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32I-ILP32-NEXT:    addi sp, sp, 16
+; RV32I-ILP32-NEXT:    ret
+;
+; RV64I-LP64-LABEL: bfloat_to_double:
+; RV64I-LP64:       # %bb.0:
+; RV64I-LP64-NEXT:    addi sp, sp, -16
+; RV64I-LP64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64I-LP64-NEXT:    slli a0, a0, 48
+; RV64I-LP64-NEXT:    srli a0, a0, 32
+; RV64I-LP64-NEXT:    call __extendsfdf2@plt
+; RV64I-LP64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64I-LP64-NEXT:    addi sp, sp, 16
+; RV64I-LP64-NEXT:    ret
+  %1 = fpext bfloat %a to double
+  ret double %1
+}
+
+define bfloat @bfloat_add(bfloat %a, bfloat %b) nounwind {
+; RV32I-ILP32-LABEL: bfloat_add:
+; RV32I-ILP32:       # %bb.0:
+; RV32I-ILP32-NEXT:    addi sp, sp, -16
+; RV32I-ILP32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32I-ILP32-NEXT:    slli a0, a0, 16
+; RV32I-ILP32-NEXT:    slli a1, a1, 16
+; RV32I-ILP32-NEXT:    call __addsf3@plt
+; RV32I-ILP32-NEXT:    call __truncsfbf2@plt
+; RV32I-ILP32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32I-ILP32-NEXT:    addi sp, sp, 16
+; RV32I-ILP32-NEXT:    ret
+;
+; RV64I-LP64-LABEL: bfloat_add:
+; RV64I-LP64:       # %bb.0:
+; RV64I-LP64-NEXT:    addi sp, sp, -16
+; RV64I-LP64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64I-LP64-NEXT:    slliw a0, a0, 16
+; RV64I-LP64-NEXT:    slliw a1, a1, 16
+; RV64I-LP64-NEXT:    call __addsf3@plt
+; RV64I-LP64-NEXT:    call __truncsfbf2@plt
+; RV64I-LP64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64I-LP64-NEXT:    addi sp, sp, 16
+; RV64I-LP64-NEXT:    ret
+  %1 = fadd bfloat %a, %b
+  ret bfloat %1
+}