[X86] LowerTRUNCATE - avoid creating extract_subvector(bitcast(vec)) patterns
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Tue, 31 May 2022 13:30:28 +0000 (14:30 +0100)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Tue, 31 May 2022 13:30:56 +0000 (14:30 +0100)
We have a generic DAG combine to attempt to fold extract_subvector(bitcast(vec)) -> bitcast(extract_subvector(vec)) but if we create these patterns late in lowering then we often miss them.

Noticed while investigating Issue #55648 which gets caught in an infinite loop trying to split extract_subvector(bitcast(vselect()) patterns - this doesn't fix the issue yet but reduces the regressions from the WIP fix.

llvm/lib/Target/X86/X86ISelLowering.cpp

index 6b38e14..324161e 100644 (file)
@@ -21934,27 +21934,25 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const {
   assert(VT.is128BitVector() && InVT.is256BitVector() && "Unexpected types!");
 
   if ((VT == MVT::v4i32) && (InVT == MVT::v4i64)) {
-    In = DAG.getBitcast(MVT::v8i32, In);
-
     // On AVX2, v4i64 -> v4i32 becomes VPERMD.
     if (Subtarget.hasInt256()) {
       static const int ShufMask[] = {0, 2, 4, 6, -1, -1, -1, -1};
+      In = DAG.getBitcast(MVT::v8i32, In);
       In = DAG.getVectorShuffle(MVT::v8i32, DL, In, In, ShufMask);
       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, In,
                          DAG.getIntPtrConstant(0, DL));
     }
 
-    SDValue OpLo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v4i32, In,
+    SDValue OpLo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v2i64, In,
                                DAG.getIntPtrConstant(0, DL));
-    SDValue OpHi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v4i32, In,
-                               DAG.getIntPtrConstant(4, DL));
+    SDValue OpHi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v2i64, In,
+                               DAG.getIntPtrConstant(2, DL));
     static const int ShufMask[] = {0, 2, 4, 6};
-    return DAG.getVectorShuffle(VT, DL, OpLo, OpHi, ShufMask);
+    return DAG.getVectorShuffle(VT, DL, DAG.getBitcast(MVT::v4i32, OpLo),
+                                DAG.getBitcast(MVT::v4i32, OpHi), ShufMask);
   }
 
   if ((VT == MVT::v8i16) && (InVT == MVT::v8i32)) {
-    In = DAG.getBitcast(MVT::v32i8, In);
-
     // On AVX2, v8i32 -> v8i16 becomes PSHUFB.
     if (Subtarget.hasInt256()) {
       // The PSHUFB mask:
@@ -21962,27 +21960,30 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const {
                                       -1, -1, -1, -1, -1, -1, -1, -1,
                                       16, 17, 20, 21, 24, 25, 28, 29,
                                       -1, -1, -1, -1, -1, -1, -1, -1 };
+      In = DAG.getBitcast(MVT::v32i8, In);
       In = DAG.getVectorShuffle(MVT::v32i8, DL, In, In, ShufMask1);
       In = DAG.getBitcast(MVT::v4i64, In);
 
       static const int ShufMask2[] = {0, 2, -1, -1};
       In = DAG.getVectorShuffle(MVT::v4i64, DL, In, In, ShufMask2);
-      return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i16,
-                         DAG.getBitcast(MVT::v16i16, In),
-                         DAG.getIntPtrConstant(0, DL));
+      In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v2i64, In,
+                       DAG.getIntPtrConstant(0, DL));
+      return DAG.getBitcast(MVT::v8i16, In);
     }
 
-    SDValue OpLo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v16i8, In,
+    SDValue OpLo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i32, In,
                                DAG.getIntPtrConstant(0, DL));
-    SDValue OpHi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v16i8, In,
-                               DAG.getIntPtrConstant(16, DL));
+    SDValue OpHi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i32, In,
+                               DAG.getIntPtrConstant(4, DL));
 
     // The PSHUFB mask:
-    static const int ShufMask1[] = {0,  1,  4,  5,  8,  9, 12, 13,
-                                   -1, -1, -1, -1, -1, -1, -1, -1};
+    static const int ShufMask1[] = {0, 2, 4, 6, -1, -1, -1, -1};
+
+    OpLo = DAG.getBitcast(MVT::v8i16, OpLo);
+    OpHi = DAG.getBitcast(MVT::v8i16, OpHi);
 
-    OpLo = DAG.getVectorShuffle(MVT::v16i8, DL, OpLo, OpLo, ShufMask1);
-    OpHi = DAG.getVectorShuffle(MVT::v16i8, DL, OpHi, OpHi, ShufMask1);
+    OpLo = DAG.getVectorShuffle(MVT::v8i16, DL, OpLo, OpLo, ShufMask1);
+    OpHi = DAG.getVectorShuffle(MVT::v8i16, DL, OpHi, OpHi, ShufMask1);
 
     OpLo = DAG.getBitcast(MVT::v4i32, OpLo);
     OpHi = DAG.getBitcast(MVT::v4i32, OpHi);