[X86] printConstant - fix asm comment issue when broadcasting from a wider constant...
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Wed, 31 May 2023 11:28:17 +0000 (12:28 +0100)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Wed, 31 May 2023 11:28:17 +0000 (12:28 +0100)
In cases where a broadcast op is loading from a constant entry wider than the broadcast element, we were incorrectly printing the entire entry and not just the lower bits referenced by the broadcast.

llvm/lib/Target/X86/X86MCInstLower.cpp
llvm/test/CodeGen/X86/vector-shuffle-combining-avx2.ll

index 2cbc31e..92cd095 100644 (file)
@@ -1523,7 +1523,8 @@ static void printConstant(const APFloat &Flt, raw_ostream &CS) {
   CS << Str;
 }
 
-static void printConstant(const Constant *COp, raw_ostream &CS) {
+static void printConstant(const Constant *COp, unsigned BitWidth,
+                          raw_ostream &CS) {
   if (isa<UndefValue>(COp)) {
     CS << "u";
   } else if (auto *CI = dyn_cast<ConstantInt>(COp)) {
@@ -1534,7 +1535,10 @@ static void printConstant(const Constant *COp, raw_ostream &CS) {
     Type *EltTy = CDS->getElementType();
     bool IsInteger = EltTy->isIntegerTy();
     bool IsFP = EltTy->isHalfTy() || EltTy->isFloatTy() || EltTy->isDoubleTy();
-    for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) {
+    unsigned EltBits = EltTy->getPrimitiveSizeInBits();
+    unsigned E = std::min(BitWidth / EltBits, CDS->getNumElements());
+    assert((BitWidth % EltBits) == 0 && "Broadcast element size mismatch");
+    for (unsigned I = 0; I != E; ++I) {
       if (I != 0)
         CS << ",";
       if (IsInteger)
@@ -1914,7 +1918,8 @@ static void addConstantComments(const MachineInstr *MI,
                ++i) {
             if (i != 0 || l != 0)
               CS << ",";
-            printConstant(CV->getOperand(i), CS);
+            printConstant(CV->getOperand(i),
+                          CV->getType()->getPrimitiveSizeInBits(), CS);
           }
         }
         CS << ">";
@@ -1957,40 +1962,40 @@ static void addConstantComments(const MachineInstr *MI,
     assert(MI->getNumOperands() >= (1 + X86::AddrNumOperands) &&
            "Unexpected number of operands!");
     if (auto *C = getConstantFromPool(*MI, MI->getOperand(1 + X86::AddrDisp))) {
-      int NumElts;
+      int NumElts, EltBits;
       switch (MI->getOpcode()) {
       default: llvm_unreachable("Invalid opcode");
-      case X86::MOVDDUPrm:          NumElts = 2;  break;
-      case X86::VMOVDDUPrm:         NumElts = 2;  break;
-      case X86::VMOVDDUPZ128rm:     NumElts = 2;  break;
-      case X86::VBROADCASTSSrm:     NumElts = 4;  break;
-      case X86::VBROADCASTSSYrm:    NumElts = 8;  break;
-      case X86::VBROADCASTSSZ128rm: NumElts = 4;  break;
-      case X86::VBROADCASTSSZ256rm: NumElts = 8;  break;
-      case X86::VBROADCASTSSZrm:    NumElts = 16; break;
-      case X86::VBROADCASTSDYrm:    NumElts = 4;  break;
-      case X86::VBROADCASTSDZ256rm: NumElts = 4;  break;
-      case X86::VBROADCASTSDZrm:    NumElts = 8;  break;
-      case X86::VPBROADCASTBrm:     NumElts = 16; break;
-      case X86::VPBROADCASTBYrm:    NumElts = 32; break;
-      case X86::VPBROADCASTBZ128rm: NumElts = 16; break;
-      case X86::VPBROADCASTBZ256rm: NumElts = 32; break;
-      case X86::VPBROADCASTBZrm:    NumElts = 64; break;
-      case X86::VPBROADCASTDrm:     NumElts = 4;  break;
-      case X86::VPBROADCASTDYrm:    NumElts = 8;  break;
-      case X86::VPBROADCASTDZ128rm: NumElts = 4;  break;
-      case X86::VPBROADCASTDZ256rm: NumElts = 8;  break;
-      case X86::VPBROADCASTDZrm:    NumElts = 16; break;
-      case X86::VPBROADCASTQrm:     NumElts = 2;  break;
-      case X86::VPBROADCASTQYrm:    NumElts = 4;  break;
-      case X86::VPBROADCASTQZ128rm: NumElts = 2;  break;
-      case X86::VPBROADCASTQZ256rm: NumElts = 4;  break;
-      case X86::VPBROADCASTQZrm:    NumElts = 8;  break;
-      case X86::VPBROADCASTWrm:     NumElts = 8;  break;
-      case X86::VPBROADCASTWYrm:    NumElts = 16; break;
-      case X86::VPBROADCASTWZ128rm: NumElts = 8;  break;
-      case X86::VPBROADCASTWZ256rm: NumElts = 16; break;
-      case X86::VPBROADCASTWZrm:    NumElts = 32; break;
+      case X86::MOVDDUPrm:          NumElts = 2;  EltBits = 64; break;
+      case X86::VMOVDDUPrm:         NumElts = 2;  EltBits = 64; break;
+      case X86::VMOVDDUPZ128rm:     NumElts = 2;  EltBits = 64; break;
+      case X86::VBROADCASTSSrm:     NumElts = 4;  EltBits = 32; break;
+      case X86::VBROADCASTSSYrm:    NumElts = 8;  EltBits = 32; break;
+      case X86::VBROADCASTSSZ128rm: NumElts = 4;  EltBits = 32; break;
+      case X86::VBROADCASTSSZ256rm: NumElts = 8;  EltBits = 32; break;
+      case X86::VBROADCASTSSZrm:    NumElts = 16; EltBits = 32; break;
+      case X86::VBROADCASTSDYrm:    NumElts = 4;  EltBits = 64; break;
+      case X86::VBROADCASTSDZ256rm: NumElts = 4;  EltBits = 64; break;
+      case X86::VBROADCASTSDZrm:    NumElts = 8;  EltBits = 64; break;
+      case X86::VPBROADCASTBrm:     NumElts = 16; EltBits = 8; break;
+      case X86::VPBROADCASTBYrm:    NumElts = 32; EltBits = 8; break;
+      case X86::VPBROADCASTBZ128rm: NumElts = 16; EltBits = 8; break;
+      case X86::VPBROADCASTBZ256rm: NumElts = 32; EltBits = 8; break;
+      case X86::VPBROADCASTBZrm:    NumElts = 64; EltBits = 8; break;
+      case X86::VPBROADCASTDrm:     NumElts = 4;  EltBits = 32; break;
+      case X86::VPBROADCASTDYrm:    NumElts = 8;  EltBits = 32; break;
+      case X86::VPBROADCASTDZ128rm: NumElts = 4;  EltBits = 32; break;
+      case X86::VPBROADCASTDZ256rm: NumElts = 8;  EltBits = 32; break;
+      case X86::VPBROADCASTDZrm:    NumElts = 16; EltBits = 32; break;
+      case X86::VPBROADCASTQrm:     NumElts = 2;  EltBits = 64; break;
+      case X86::VPBROADCASTQYrm:    NumElts = 4;  EltBits = 64; break;
+      case X86::VPBROADCASTQZ128rm: NumElts = 2;  EltBits = 64; break;
+      case X86::VPBROADCASTQZ256rm: NumElts = 4;  EltBits = 64; break;
+      case X86::VPBROADCASTQZrm:    NumElts = 8;  EltBits = 64; break;
+      case X86::VPBROADCASTWrm:     NumElts = 8;  EltBits = 16; break;
+      case X86::VPBROADCASTWYrm:    NumElts = 16; EltBits = 16; break;
+      case X86::VPBROADCASTWZ128rm: NumElts = 8;  EltBits = 16; break;
+      case X86::VPBROADCASTWZ256rm: NumElts = 16; EltBits = 16; break;
+      case X86::VPBROADCASTWZrm:    NumElts = 32; EltBits = 16; break;
       }
 
       std::string Comment;
@@ -2001,7 +2006,7 @@ static void addConstantComments(const MachineInstr *MI,
       for (int i = 0; i != NumElts; ++i) {
         if (i != 0)
           CS << ",";
-        printConstant(C, CS);
+        printConstant(C, EltBits, CS);
       }
       CS << "]";
       OutStreamer.AddComment(CS.str());
index 91f550b..2976313 100644 (file)
@@ -874,7 +874,7 @@ define void @PR63030(ptr %p0) {
 ; X86-AVX2:       # %bb.0:
 ; X86-AVX2-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-AVX2-NEXT:    vmovaps (%eax), %xmm0
-; X86-AVX2-NEXT:    vmovddup {{.*#+}} xmm1 = [3,0,2,0,3,0,2,0]
+; X86-AVX2-NEXT:    vmovddup {{.*#+}} xmm1 = [3,0,3,0]
 ; X86-AVX2-NEXT:    # xmm1 = mem[0,0]
 ; X86-AVX2-NEXT:    vpermpd {{.*#+}} ymm2 = ymm0[1,1,0,0]
 ; X86-AVX2-NEXT:    vblendps {{.*#+}} ymm1 = ymm2[0,1],ymm1[2,3],ymm2[4,5,6,7]
@@ -899,7 +899,7 @@ define void @PR63030(ptr %p0) {
 ; X64-AVX2-LABEL: PR63030:
 ; X64-AVX2:       # %bb.0:
 ; X64-AVX2-NEXT:    vmovaps (%rdi), %xmm0
-; X64-AVX2-NEXT:    vmovddup {{.*#+}} xmm1 = [3,2,3,2]
+; X64-AVX2-NEXT:    vmovddup {{.*#+}} xmm1 = [3,3]
 ; X64-AVX2-NEXT:    # xmm1 = mem[0,0]
 ; X64-AVX2-NEXT:    vpermpd {{.*#+}} ymm2 = ymm0[1,1,0,0]
 ; X64-AVX2-NEXT:    vblendps {{.*#+}} ymm1 = ymm2[0,1],ymm1[2,3],ymm2[4,5,6,7]