The ABI of BF16 is identical to FP16 rather than i16.
Fixes #62997
Reviewed By: RKSimon
Differential Revision: https://reviews.llvm.org/D151710
return Val;
if (PartEVT.isInteger() && ValueVT.isFloatingPoint())
return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
+
+ // Vector/Vector bitcast (e.g. <2 x bfloat> -> <2 x half>).
+ if (ValueVT.getSizeInBits() == PartEVT.getSizeInBits())
+ return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
}
// Promoted vector extract
return SDValue();
EVT ValueVT = Val.getValueType();
+ EVT PartEVT = PartVT.getVectorElementType();
+ EVT ValueEVT = ValueVT.getVectorElementType();
ElementCount PartNumElts = PartVT.getVectorElementCount();
ElementCount ValueNumElts = ValueVT.getVectorElementCount();
// fixed/scalable properties. If a target needs to widen a fixed-length type
// to a scalable one, it should be possible to use INSERT_SUBVECTOR below.
if (ElementCount::isKnownLE(PartNumElts, ValueNumElts) ||
- PartNumElts.isScalable() != ValueNumElts.isScalable() ||
- PartVT.getVectorElementType() != ValueVT.getVectorElementType())
+ PartNumElts.isScalable() != ValueNumElts.isScalable())
return SDValue();
+ // Have a try for bf16 because some targets share its ABI with fp16.
+ if (ValueEVT == MVT::bf16 && PartEVT == MVT::f16) {
+ assert(DAG.getTargetLoweringInfo().isTypeLegal(PartVT) &&
+ "Cannot widen to illegal type");
+ Val = DAG.getNode(ISD::BITCAST, DL,
+ ValueVT.changeVectorElementType(MVT::f16), Val);
+ } else if (PartEVT != ValueEVT) {
+ return SDValue();
+ }
+
// Widening a scalable vector to another scalable vector is done by inserting
// the vector into a larger undef one.
if (PartNumElts.isScalable())
return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, PartVT, DAG.getUNDEF(PartVT),
Val, DAG.getVectorIdxConstant(0, DL));
- EVT ElementVT = PartVT.getVectorElementType();
// Vector widening case, e.g. <2 x float> -> <4 x float>. Shuffle in
// undef elements.
SmallVector<SDValue, 16> Ops;
DAG.ExtractVectorElements(Val, Ops);
- SDValue EltUndef = DAG.getUNDEF(ElementVT);
+ SDValue EltUndef = DAG.getUNDEF(PartEVT);
Ops.append((PartNumElts - ValueNumElts).getFixedValue(), EltUndef);
// FIXME: Use CONCAT for 2x -> 4x.
if (VT.isVector() && VT.getVectorElementType() == MVT::bf16)
return getRegisterTypeForCallingConv(Context, CC,
- VT.changeVectorElementTypeToInteger());
+ VT.changeVectorElementType(MVT::f16));
return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
}
if (VT.isVector() && VT.getVectorElementType() == MVT::bf16)
return getNumRegistersForCallingConv(Context, CC,
- VT.changeVectorElementTypeToInteger());
+ VT.changeVectorElementType(MVT::f16));
return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
}
; SSE2-NEXT: movq %rdx, %rax
; SSE2-NEXT: shrq $48, %rax
; SSE2-NEXT: movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
-; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm0[2,3,2,3]
+; SSE2-NEXT: punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
; SSE2-NEXT: movq %xmm0, %r12
; SSE2-NEXT: movq %r12, %rax
; SSE2-NEXT: shrq $32, %rax
; SSE2-NEXT: movq %rax, (%rsp) # 8-byte Spill
-; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
-; SSE2-NEXT: movq %xmm0, %r14
+; SSE2-NEXT: punpckhqdq {{.*#+}} xmm1 = xmm1[1,1]
+; SSE2-NEXT: movq %xmm1, %r14
; SSE2-NEXT: movq %r14, %rbp
; SSE2-NEXT: shrq $32, %rbp
; SSE2-NEXT: movq %r12, %r15
%add = fadd <8 x bfloat> %a, %b
ret <8 x bfloat> %add
}
+
+define <2 x bfloat> @pr62997(bfloat %a, bfloat %b) {
+; SSE2-LABEL: pr62997:
+; SSE2: # %bb.0:
+; SSE2-NEXT: movd %xmm0, %eax
+; SSE2-NEXT: movd %xmm1, %ecx
+; SSE2-NEXT: pinsrw $0, %ecx, %xmm1
+; SSE2-NEXT: pinsrw $0, %eax, %xmm0
+; SSE2-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; SSE2-NEXT: retq
+;
+; BF16-LABEL: pr62997:
+; BF16: # %bb.0:
+; BF16-NEXT: vmovd %xmm1, %eax
+; BF16-NEXT: vmovd %xmm0, %ecx
+; BF16-NEXT: vmovd %ecx, %xmm0
+; BF16-NEXT: vpinsrw $1, %eax, %xmm0, %xmm0
+; BF16-NEXT: retq
+ %1 = insertelement <2 x bfloat> undef, bfloat %a, i64 0
+ %2 = insertelement <2 x bfloat> %1, bfloat %b, i64 1
+ ret <2 x bfloat> %2
+}