[X86][BF16] Share FP16 vector ABI with BF16
authorPhoebe Wang <phoebe.wang@intel.com>
Fri, 9 Jun 2023 01:03:54 +0000 (09:03 +0800)
committerPhoebe Wang <phoebe.wang@intel.com>
Fri, 9 Jun 2023 01:04:56 +0000 (09:04 +0800)
The ABI of BF16 is identical to FP16 rather than i16.

Fixes #62997

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D151710

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/bfloat.ll

index 83aacd6..0f2059b 100644 (file)
@@ -417,6 +417,10 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
         return Val;
       if (PartEVT.isInteger() && ValueVT.isFloatingPoint())
         return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
+
+      // Vector/Vector bitcast (e.g. <2 x bfloat> -> <2 x half>).
+      if (ValueVT.getSizeInBits() == PartEVT.getSizeInBits())
+        return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
     }
 
     // Promoted vector extract
@@ -622,6 +626,8 @@ static SDValue widenVectorToPartType(SelectionDAG &DAG, SDValue Val,
     return SDValue();
 
   EVT ValueVT = Val.getValueType();
+  EVT PartEVT = PartVT.getVectorElementType();
+  EVT ValueEVT = ValueVT.getVectorElementType();
   ElementCount PartNumElts = PartVT.getVectorElementCount();
   ElementCount ValueNumElts = ValueVT.getVectorElementCount();
 
@@ -629,22 +635,30 @@ static SDValue widenVectorToPartType(SelectionDAG &DAG, SDValue Val,
   // fixed/scalable properties. If a target needs to widen a fixed-length type
   // to a scalable one, it should be possible to use INSERT_SUBVECTOR below.
   if (ElementCount::isKnownLE(PartNumElts, ValueNumElts) ||
-      PartNumElts.isScalable() != ValueNumElts.isScalable() ||
-      PartVT.getVectorElementType() != ValueVT.getVectorElementType())
+      PartNumElts.isScalable() != ValueNumElts.isScalable())
     return SDValue();
 
+  // Have a try for bf16 because some targets share its ABI with fp16.
+  if (ValueEVT == MVT::bf16 && PartEVT == MVT::f16) {
+    assert(DAG.getTargetLoweringInfo().isTypeLegal(PartVT) &&
+           "Cannot widen to illegal type");
+    Val = DAG.getNode(ISD::BITCAST, DL,
+                      ValueVT.changeVectorElementType(MVT::f16), Val);
+  } else if (PartEVT != ValueEVT) {
+    return SDValue();
+  }
+
   // Widening a scalable vector to another scalable vector is done by inserting
   // the vector into a larger undef one.
   if (PartNumElts.isScalable())
     return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, PartVT, DAG.getUNDEF(PartVT),
                        Val, DAG.getVectorIdxConstant(0, DL));
 
-  EVT ElementVT = PartVT.getVectorElementType();
   // Vector widening case, e.g. <2 x float> -> <4 x float>.  Shuffle in
   // undef elements.
   SmallVector<SDValue, 16> Ops;
   DAG.ExtractVectorElements(Val, Ops);
-  SDValue EltUndef = DAG.getUNDEF(ElementVT);
+  SDValue EltUndef = DAG.getUNDEF(PartEVT);
   Ops.append((PartNumElts - ValueNumElts).getFixedValue(), EltUndef);
 
   // FIXME: Use CONCAT for 2x -> 4x.
index 170396d..0bab667 100644 (file)
@@ -2608,7 +2608,7 @@ MVT X86TargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
 
   if (VT.isVector() && VT.getVectorElementType() == MVT::bf16)
     return getRegisterTypeForCallingConv(Context, CC,
-                                         VT.changeVectorElementTypeToInteger());
+                                         VT.changeVectorElementType(MVT::f16));
 
   return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
 }
@@ -2643,7 +2643,7 @@ unsigned X86TargetLowering::getNumRegistersForCallingConv(LLVMContext &Context,
 
   if (VT.isVector() && VT.getVectorElementType() == MVT::bf16)
     return getNumRegistersForCallingConv(Context, CC,
-                                         VT.changeVectorElementTypeToInteger());
+                                         VT.changeVectorElementType(MVT::f16));
 
   return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
 }
index ca33823..c67c947 100644 (file)
@@ -317,13 +317,13 @@ define <8 x bfloat> @addv(<8 x bfloat> %a, <8 x bfloat> %b) nounwind {
 ; SSE2-NEXT:    movq %rdx, %rax
 ; SSE2-NEXT:    shrq $48, %rax
 ; SSE2-NEXT:    movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
-; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[2,3,2,3]
+; SSE2-NEXT:    punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
 ; SSE2-NEXT:    movq %xmm0, %r12
 ; SSE2-NEXT:    movq %r12, %rax
 ; SSE2-NEXT:    shrq $32, %rax
 ; SSE2-NEXT:    movq %rax, (%rsp) # 8-byte Spill
-; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
-; SSE2-NEXT:    movq %xmm0, %r14
+; SSE2-NEXT:    punpckhqdq {{.*#+}} xmm1 = xmm1[1,1]
+; SSE2-NEXT:    movq %xmm1, %r14
 ; SSE2-NEXT:    movq %r14, %rbp
 ; SSE2-NEXT:    shrq $32, %rbp
 ; SSE2-NEXT:    movq %r12, %r15
@@ -543,3 +543,25 @@ define <8 x bfloat> @addv(<8 x bfloat> %a, <8 x bfloat> %b) nounwind {
   %add = fadd <8 x bfloat> %a, %b
   ret <8 x bfloat> %add
 }
+
+define <2 x bfloat> @pr62997(bfloat %a, bfloat %b) {
+; SSE2-LABEL: pr62997:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    movd %xmm1, %ecx
+; SSE2-NEXT:    pinsrw $0, %ecx, %xmm1
+; SSE2-NEXT:    pinsrw $0, %eax, %xmm0
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; SSE2-NEXT:    retq
+;
+; BF16-LABEL: pr62997:
+; BF16:       # %bb.0:
+; BF16-NEXT:    vmovd %xmm1, %eax
+; BF16-NEXT:    vmovd %xmm0, %ecx
+; BF16-NEXT:    vmovd %ecx, %xmm0
+; BF16-NEXT:    vpinsrw $1, %eax, %xmm0, %xmm0
+; BF16-NEXT:    retq
+  %1 = insertelement <2 x bfloat> undef, bfloat %a, i64 0
+  %2 = insertelement <2 x bfloat> %1, bfloat %b, i64 1
+  ret <2 x bfloat> %2
+}