[AArch64][SVE] Optimize bitcasts between unpacked half/i16 vectors.
authorSander de Smalen <sander.desmalen@arm.com>
Mon, 19 Jul 2021 06:13:14 +0000 (07:13 +0100)
committerSander de Smalen <sander.desmalen@arm.com>
Mon, 19 Jul 2021 07:29:28 +0000 (08:29 +0100)
The case for nxv2f32/nxv2i32 was already covered by D104573.
This patch builds on top of that by making the mechanism work for
nxv2[b]f16/nxv2i16, nxv4[b]f16/nxv4i16 as well.

Reviewed By: efriedma

Differential Revision: https://reviews.llvm.org/D106138

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/test/CodeGen/AArch64/sve-bitcast.ll

index 8cda50d..6ae073e 100644 (file)
@@ -1194,7 +1194,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     }
 
     // Legalize unpacked bitcasts to REINTERPRET_CAST.
-    for (auto VT : {MVT::nxv2i32, MVT::nxv2f32})
+    for (auto VT : {MVT::nxv2i16, MVT::nxv4i16, MVT::nxv2i32, MVT::nxv2bf16,
+                    MVT::nxv2f16, MVT::nxv4f16, MVT::nxv2f32})
       setOperationAction(ISD::BITCAST, VT, Custom);
 
     for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) {
@@ -3520,14 +3521,16 @@ SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
   if (useSVEForFixedLengthVectorVT(OpVT))
     return LowerFixedLengthBitcastToSVE(Op, DAG);
 
-  if (OpVT == MVT::nxv2f32) {
-    if (ArgVT.isInteger()) {
+  if (OpVT.isScalableVector()) {
+    if (isTypeLegal(OpVT) && !isTypeLegal(ArgVT)) {
+      assert(OpVT.isFloatingPoint() && !ArgVT.isFloatingPoint() &&
+             "Expected int->fp bitcast!");
       SDValue ExtResult =
           DAG.getNode(ISD::ANY_EXTEND, SDLoc(Op), getSVEContainerType(ArgVT),
                       Op.getOperand(0));
-      return getSVESafeBitCast(MVT::nxv2f32, ExtResult, DAG);
+      return getSVESafeBitCast(OpVT, ExtResult, DAG);
     }
-    return getSVESafeBitCast(MVT::nxv2f32, Op.getOperand(0), DAG);
+    return getSVESafeBitCast(OpVT, Op.getOperand(0), DAG);
   }
 
   if (OpVT != MVT::f16 && OpVT != MVT::bf16)
@@ -16944,16 +16947,18 @@ void AArch64TargetLowering::ReplaceBITCASTResults(
     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
   SDLoc DL(N);
   SDValue Op = N->getOperand(0);
+  EVT VT = N->getValueType(0);
+  EVT SrcVT = Op.getValueType();
 
-  if (N->getValueType(0) == MVT::nxv2i32 &&
-      Op.getValueType().isFloatingPoint()) {
-    SDValue CastResult = getSVESafeBitCast(MVT::nxv2i64, Op, DAG);
-    Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::nxv2i32, CastResult));
+  if (VT.isScalableVector() && !isTypeLegal(VT) && isTypeLegal(SrcVT)) {
+    assert(!VT.isFloatingPoint() && SrcVT.isFloatingPoint() &&
+           "Expected fp->int bitcast!");
+    SDValue CastResult = getSVESafeBitCast(getSVEContainerType(VT), Op, DAG);
+    Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, CastResult));
     return;
   }
 
-  if (N->getValueType(0) != MVT::i16 ||
-      (Op.getValueType() != MVT::f16 && Op.getValueType() != MVT::bf16))
+  if (VT != MVT::i16 || (SrcVT != MVT::f16 && SrcVT != MVT::bf16))
     return;
 
   Op = SDValue(
index dda4232..bab42f3 100644 (file)
@@ -450,6 +450,70 @@ define <vscale x 8 x bfloat> @bitcast_double_to_bfloat(<vscale x 2 x double> %v)
   ret <vscale x 8 x bfloat> %bc
 }
 
+define <vscale x 2 x i16> @bitcast_short2_half_to_i16(<vscale x 2 x half> %v) {
+; CHECK-LABEL: bitcast_short2_half_to_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %bc = bitcast <vscale x 2 x half> %v to <vscale x 2 x i16>
+  ret <vscale x 2 x i16> %bc
+}
+
+define <vscale x 4 x i16> @bitcast_short4_half_to_i16(<vscale x 4 x half> %v) {
+; CHECK-LABEL: bitcast_short4_half_to_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %bc = bitcast <vscale x 4 x half> %v to <vscale x 4 x i16>
+  ret <vscale x 4 x i16> %bc
+}
+
+define <vscale x 2 x i16> @bitcast_short2_bfloat_to_i16(<vscale x 2 x bfloat> %v) #0 {
+; CHECK-LABEL: bitcast_short2_bfloat_to_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %bc = bitcast <vscale x 2 x bfloat> %v to <vscale x 2 x i16>
+  ret <vscale x 2 x i16> %bc
+}
+
+define <vscale x 4 x i16> @bitcast_short4_bfloat_to_i16(<vscale x 4 x bfloat> %v) #0 {
+; CHECK-LABEL: bitcast_short4_bfloat_to_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %bc = bitcast <vscale x 4 x bfloat> %v to <vscale x 4 x i16>
+  ret <vscale x 4 x i16> %bc
+}
+
+define <vscale x 2 x half> @bitcast_short2_i16_to_half(<vscale x 2 x i16> %v) {
+; CHECK-LABEL: bitcast_short2_i16_to_half:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %bc = bitcast <vscale x 2 x i16> %v to <vscale x 2 x half>
+  ret <vscale x 2 x half> %bc
+}
+
+define <vscale x 4 x half> @bitcast_short4_i16_to_half(<vscale x 4 x i16> %v) {
+; CHECK-LABEL: bitcast_short4_i16_to_half:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %bc = bitcast <vscale x 4 x i16> %v to <vscale x 4 x half>
+  ret <vscale x 4 x half> %bc
+}
+
+define <vscale x 2 x bfloat> @bitcast_short2_i16_to_bfloat(<vscale x 2 x i16> %v) #0 {
+; CHECK-LABEL: bitcast_short2_i16_to_bfloat:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %bc = bitcast <vscale x 2 x i16> %v to <vscale x 2 x bfloat>
+  ret <vscale x 2 x bfloat> %bc
+}
+
+define <vscale x 4 x bfloat> @bitcast_short4_i16_to_bfloat(<vscale x 4 x i16> %v) #0 {
+; CHECK-LABEL: bitcast_short4_i16_to_bfloat:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %bc = bitcast <vscale x 4 x i16> %v to <vscale x 4 x bfloat>
+  ret <vscale x 4 x bfloat> %bc
+}
+
 define <vscale x 2 x i32> @bitcast_short_float_to_i32(<vscale x 2 x double> %v) #0 {
 ; CHECK-LABEL: bitcast_short_float_to_i32:
 ; CHECK:       // %bb.0: