[X86] CombineShuffleWithExtract - handle cases with different vector extract sources
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Sun, 16 Jun 2019 08:00:41 +0000 (08:00 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Sun, 16 Jun 2019 08:00:41 +0000 (08:00 +0000)
Insert the shorter vector source into an undef vector of the longer vector source's type.

llvm-svn: 363507

llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/avx512-shuffles/partial_permute.ll

index 0a9a632fdb18c772f49513700a6a1be35d302b35..78726818927d5fb942cee306f428ac9cada2c8b4 100644 (file)
@@ -31951,19 +31951,41 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
         !isa<ConstantSDNode>(V2.getOperand(1)))
       return false;
 
+    // If the src vector types aren't the same, see if we can extend
+    // one to match the other.
     SDValue Src1 = V1.getOperand(0);
     SDValue Src2 = V2.getOperand(0);
-    if (Src1.getValueType() != Src2.getValueType())
+    if ((Src1.getValueType().getScalarType() !=
+         Src2.getValueType().getScalarType()) ||
+        !DAG.getTargetLoweringInfo().isTypeLegal(Src1.getValueType()) ||
+        !DAG.getTargetLoweringInfo().isTypeLegal(Src2.getValueType()))
       return false;
 
+    unsigned Src1SizeInBits = Src1.getValueSizeInBits();
+    unsigned Src2SizeInBits = Src2.getValueSizeInBits();
+    assert(((Src1SizeInBits % Src2SizeInBits) == 0 ||
+            (Src2SizeInBits % Src1SizeInBits) == 0) &&
+           "Shuffle vector size mismatch");
+    if (Src1SizeInBits != Src2SizeInBits) {
+      if (Src1SizeInBits > Src2SizeInBits) {
+        Src2 = insertSubVector(DAG.getUNDEF(Src1.getValueType()), Src2, 0, DAG,
+                               DL, Src2SizeInBits);
+        Src2SizeInBits = Src1SizeInBits;
+      } else {
+        Src1 = insertSubVector(DAG.getUNDEF(Src2.getValueType()), Src1, 0, DAG,
+                               DL, Src1SizeInBits);
+        Src1SizeInBits = Src2SizeInBits;
+      }
+    }
+
     unsigned Offset1 = V1.getConstantOperandVal(1);
     unsigned Offset2 = V2.getConstantOperandVal(1);
-    assert(((Offset1 % VT1.getVectorNumElements()) == 0 ||
-            (Offset2 % VT2.getVectorNumElements()) == 0 ||
-            (Src1.getValueSizeInBits() % RootSizeInBits) == 0 ||
-            (Src2.getValueSizeInBits() % RootSizeInBits) == 0) &&
+    assert(((Offset1 % VT1.getVectorNumElements()) == 0 &&
+            (Offset2 % VT2.getVectorNumElements()) == 0 &&
+            (Src1SizeInBits % RootSizeInBits) == 0 &&
+            Src1SizeInBits == Src2SizeInBits) &&
            "Unexpected subvector extraction");
-    unsigned Scale = Src1.getValueSizeInBits() / RootSizeInBits;
+    unsigned Scale = Src1SizeInBits / RootSizeInBits;
 
     // Convert extraction indices to mask size.
     Offset1 /= VT1.getVectorNumElements();
index ec3f32381be1187e9d8bb257f87bdad238fcc830..b00344607b940966eb500d3855857910e8c28a5a 100644 (file)
@@ -3166,11 +3166,9 @@ define <4 x float> @test_masked_z_16xfloat_to_4xfloat_perm_mask2(<16 x float> %v
 define <4 x float> @test_16xfloat_to_4xfloat_perm_mask3(<16 x float> %vec) {
 ; CHECK-LABEL: test_16xfloat_to_4xfloat_perm_mask3:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vmovaps {{.*#+}} ymm1 = [0,2,4,6,4,6,6,7]
-; CHECK-NEXT:    vpermps %ymm0, %ymm1, %ymm1
-; CHECK-NEXT:    vextractf32x4 $2, %zmm0, %xmm2
-; CHECK-NEXT:    vmovaps {{.*#+}} xmm0 = [2,5,3,7]
-; CHECK-NEXT:    vpermi2ps %xmm1, %xmm2, %xmm0
+; CHECK-NEXT:    vmovaps {{.*#+}} xmm1 = [10,18,11,22]
+; CHECK-NEXT:    vpermt2ps %zmm0, %zmm1, %zmm0
+; CHECK-NEXT:    # kill: def $xmm0 killed $xmm0 killed $zmm0
 ; CHECK-NEXT:    vzeroupper
 ; CHECK-NEXT:    retq
   %res = shufflevector <16 x float> %vec, <16 x float> undef, <4 x i32> <i32 10, i32 2, i32 11, i32 6>
@@ -3179,14 +3177,11 @@ define <4 x float> @test_16xfloat_to_4xfloat_perm_mask3(<16 x float> %vec) {
 define <4 x float> @test_masked_16xfloat_to_4xfloat_perm_mask3(<16 x float> %vec, <4 x float> %vec2, <4 x float> %mask) {
 ; CHECK-LABEL: test_masked_16xfloat_to_4xfloat_perm_mask3:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vmovaps {{.*#+}} ymm3 = [0,2,4,6,4,6,6,7]
-; CHECK-NEXT:    vpermps %ymm0, %ymm3, %ymm3
-; CHECK-NEXT:    vextractf32x4 $2, %zmm0, %xmm0
-; CHECK-NEXT:    vmovaps {{.*#+}} xmm4 = [2,5,3,7]
-; CHECK-NEXT:    vpermi2ps %xmm3, %xmm0, %xmm4
+; CHECK-NEXT:    vmovaps {{.*#+}} xmm3 = [10,18,11,22]
+; CHECK-NEXT:    vpermi2ps %zmm0, %zmm0, %zmm3
 ; CHECK-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; CHECK-NEXT:    vcmpeqps %xmm0, %xmm2, %k1
-; CHECK-NEXT:    vblendmps %xmm4, %xmm1, %xmm0 {%k1}
+; CHECK-NEXT:    vblendmps %xmm3, %xmm1, %xmm0 {%k1}
 ; CHECK-NEXT:    vzeroupper
 ; CHECK-NEXT:    retq
   %shuf = shufflevector <16 x float> %vec, <16 x float> undef, <4 x i32> <i32 10, i32 2, i32 11, i32 6>
@@ -3198,13 +3193,12 @@ define <4 x float> @test_masked_16xfloat_to_4xfloat_perm_mask3(<16 x float> %vec
 define <4 x float> @test_masked_z_16xfloat_to_4xfloat_perm_mask3(<16 x float> %vec, <4 x float> %mask) {
 ; CHECK-LABEL: test_masked_z_16xfloat_to_4xfloat_perm_mask3:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vmovaps {{.*#+}} ymm2 = [0,2,4,6,4,6,6,7]
-; CHECK-NEXT:    vpermps %ymm0, %ymm2, %ymm2
-; CHECK-NEXT:    vextractf32x4 $2, %zmm0, %xmm3
-; CHECK-NEXT:    vmovaps {{.*#+}} xmm0 = [2,5,3,7]
-; CHECK-NEXT:    vxorps %xmm4, %xmm4, %xmm4
-; CHECK-NEXT:    vcmpeqps %xmm4, %xmm1, %k1
-; CHECK-NEXT:    vpermi2ps %xmm2, %xmm3, %xmm0 {%k1} {z}
+; CHECK-NEXT:    vbroadcasti128 {{.*#+}} ymm2 = [10,18,11,22,10,18,11,22]
+; CHECK-NEXT:    # ymm2 = mem[0,1,0,1]
+; CHECK-NEXT:    vxorps %xmm3, %xmm3, %xmm3
+; CHECK-NEXT:    vcmpeqps %xmm3, %xmm1, %k1
+; CHECK-NEXT:    vpermt2ps %zmm0, %zmm2, %zmm0 {%k1} {z}
+; CHECK-NEXT:    # kill: def $xmm0 killed $xmm0 killed $zmm0
 ; CHECK-NEXT:    vzeroupper
 ; CHECK-NEXT:    retq
   %shuf = shufflevector <16 x float> %vec, <16 x float> undef, <4 x i32> <i32 10, i32 2, i32 11, i32 6>