[Target][ARM] Replace outdated getARMVPTBlockMask function
authorPierre-vh <pierre.vanhoutryve@arm.com>
Wed, 8 Apr 2020 10:55:09 +0000 (11:55 +0100)
committerPierre-vh <pierre.vanhoutryve@arm.com>
Tue, 12 May 2020 11:10:15 +0000 (12:10 +0100)
getARMVPTBlockMask was an outdated function that only handled basic
block masks: T, TT, TTT and TTTT. This worked fine before the MVE
VPT Block Insertion Pass improvements as it was the only kind of
masks that it could generate, but now it can generate more complex
masks that uses E predicates, so it's dangerous to use that function
to calculate VPT/VPST block masks.

I replaced it with 2 different functions:
  - expandPredBlockMask, in ARMBaseInfo. This adds an "E" or "T" at
    the end of an existing PredBlockMask.
  - recomputeVPTBlockMask, in Thumb2InstrInfo. This takes an iterator
    to a VPT/VPST instruction and recomputes its block mask by looking
    at the predicated instructions that follows it. This should be
    used to recompute a block mask after removing/adding a predicated
    instruction to the block.

The expandPredBlockMask function is pretty much imported from the MVE
VPT Blocks pass.

I had to change the ARMLowOverheadLoops and MVEVPTBlocks passes as well
so they could use these new functions.

Differential Revision: https://reviews.llvm.org/D78201

llvm/lib/Target/ARM/ARMBaseInstrInfo.h
llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
llvm/lib/Target/ARM/MVEVPTBlockPass.cpp
llvm/lib/Target/ARM/Thumb2InstrInfo.cpp
llvm/lib/Target/ARM/Thumb2InstrInfo.h
llvm/lib/Target/ARM/Utils/ARMBaseInfo.cpp
llvm/lib/Target/ARM/Utils/ARMBaseInfo.h

index 2fd6681..c414ae1 100644 (file)
@@ -504,25 +504,6 @@ bool isUncondBranchOpcode(int Opc) {
 // This table shows the VPT instruction variants, i.e. the different
 // mask field encodings, see also B5.6. Predication/conditional execution in
 // the ArmARM.
-
-
-inline static ARM::PredBlockMask getARMVPTBlockMask(unsigned NumInsts) {
-  switch (NumInsts) {
-  case 1:
-    return ARM::PredBlockMask::T;
-  case 2:
-    return ARM::PredBlockMask::TT;
-  case 3:
-    return ARM::PredBlockMask::TTT;
-  case 4:
-    return ARM::PredBlockMask::TTTT;
-  default:
-    break;
-  };
-  llvm_unreachable("Unexpected number of instruction in a VPT block");
-}
-
-
 static inline bool isVPTOpcode(int Opc) {
   return Opc == ARM::MVE_VPTv16i8 || Opc == ARM::MVE_VPTv16u8 ||
          Opc == ARM::MVE_VPTv16s8 || Opc == ARM::MVE_VPTv8i16 ||
index 44ddb4e..71da119 100644 (file)
@@ -191,6 +191,7 @@ namespace {
     SetVector<MachineInstr*> CurrentPredicate;
     SmallVector<VPTBlock, 4> VPTBlocks;
     SmallPtrSet<MachineInstr*, 4> ToRemove;
+    SmallPtrSet<MachineInstr*, 4> BlockMasksToRecompute;
     bool Revert = false;
     bool CannotTailPredicate = false;
 
@@ -1183,11 +1184,9 @@ void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
     if (Block.HasNonUniformPredicate()) {
       PredicatedMI *Divergent = Block.getDivergent();
       if (isVCTP(Divergent->MI)) {
-        // The vctp will be removed, so the size of the vpt block needs to be
-        // modified.
-        uint64_t Size = (uint64_t)getARMVPTBlockMask(Block.size() - 1);
-        Block.getVPST()->getOperand(0).setImm(Size);
-        LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n");
+        // The vctp will be removed, so the block mask of the VPST/VPT will need
+        // to be recomputed.
+        LoLoop.BlockMasksToRecompute.insert(Block.getVPST());
       } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
         // The VPT block has a non-uniform predicate but it's entry is guarded
         // only by a vctp, which means we:
@@ -1211,13 +1210,15 @@ void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
           ++Size;
           ++I;
         }
+        // Create a VPST with a null mask, we'll recompute it later.
         MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt,
                                           InsertAt->getDebugLoc(),
                                           TII->get(ARM::MVE_VPST));
-        MIB.addImm((uint64_t)getARMVPTBlockMask(Size));
+        MIB.addImm(0);
         LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
         LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
         LoLoop.ToRemove.insert(Block.getVPST());
+        LoLoop.BlockMasksToRecompute.insert(MIB.getInstr());
       }
     } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
       // A vpt block which is only predicated upon vctp and has no internal vpr
@@ -1288,6 +1289,11 @@ void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
       LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I);
       I->eraseFromParent();
     }
+    for (auto *I : LoLoop.BlockMasksToRecompute) {
+      LLVM_DEBUG(dbgs() << "ARM Loops: Recomputing VPT/VPST Block Mask: " << *I);
+      recomputeVPTBlockMask(*I);
+      LLVM_DEBUG(dbgs() << "           ... done: " << *I);
+    }
   }
 
   PostOrderLoopTraversal DFS(LoLoop.ML, *MLI);
index 5d98307..dc769ae 100644 (file)
@@ -94,37 +94,6 @@ static MachineInstr *findVCMPToFoldIntoVPST(MachineBasicBlock::iterator MI,
   return &*CmpMI;
 }
 
-static ARM::PredBlockMask ExpandBlockMask(ARM::PredBlockMask BlockMask,
-                                          ARMVCC::VPTCodes Kind) {
-  using PredBlockMask = ARM::PredBlockMask;
-  assert(Kind != ARMVCC::None && "Cannot expand mask with 'None'");
-  assert(countTrailingZeros((unsigned)BlockMask) != 0 &&
-         "Mask is already full");
-
-  auto ChooseMask = [&](PredBlockMask AddedThen, PredBlockMask AddedElse) {
-    return (Kind == ARMVCC::Then) ? AddedThen : AddedElse;
-  };
-
-  switch (BlockMask) {
-  case PredBlockMask::T:
-    return ChooseMask(PredBlockMask::TT, PredBlockMask::TE);
-  case PredBlockMask::TT:
-    return ChooseMask(PredBlockMask::TTT, PredBlockMask::TTE);
-  case PredBlockMask::TE:
-    return ChooseMask(PredBlockMask::TET, PredBlockMask::TEE);
-  case PredBlockMask::TTT:
-    return ChooseMask(PredBlockMask::TTTT, PredBlockMask::TTTE);
-  case PredBlockMask::TTE:
-    return ChooseMask(PredBlockMask::TTET, PredBlockMask::TTEE);
-  case PredBlockMask::TET:
-    return ChooseMask(PredBlockMask::TETT, PredBlockMask::TETE);
-  case PredBlockMask::TEE:
-    return ChooseMask(PredBlockMask::TEET, PredBlockMask::TEEE);
-  default:
-    llvm_unreachable("Unknown Mask");
-  }
-}
-
 // Advances Iter past a block of predicated instructions.
 // Returns true if it successfully skipped the whole block of predicated
 // instructions. Returns false when it stopped early (due to MaxSteps), or if
@@ -162,6 +131,22 @@ static bool IsVPRDefinedOrKilledByBlock(MachineBasicBlock::iterator Iter,
   return false;
 }
 
+// Creates a T, TT, TTT or TTTT BlockMask depending on BlockSize.
+static ARM::PredBlockMask GetInitialBlockMask(unsigned BlockSize) {
+  switch (BlockSize) {
+  case 1:
+    return ARM::PredBlockMask::T;
+  case 2:
+    return ARM::PredBlockMask::TT;
+  case 3:
+    return ARM::PredBlockMask::TTT;
+  case 4:
+    return ARM::PredBlockMask::TTTT;
+  default:
+    llvm_unreachable("Invalid BlockSize!");
+  }
+}
+
 // Given an iterator (Iter) that points at an instruction with a "Then"
 // predicate, tries to create the largest block of continuous predicated
 // instructions possible, and returns the VPT Block Mask of that block.
@@ -190,7 +175,7 @@ CreateVPTBlock(MachineBasicBlock::instr_iterator &Iter,
   });
 
   // Generate the initial BlockMask
-  ARM::PredBlockMask BlockMask = getARMVPTBlockMask(BlockSize);
+  ARM::PredBlockMask BlockMask = GetInitialBlockMask(BlockSize);
 
   // Remove VPNOTs while there's still room in the block, so we can make the
   // largest block possible.
@@ -232,7 +217,7 @@ CreateVPTBlock(MachineBasicBlock::instr_iterator &Iter,
 
       // Change the predicate and update the mask
       Iter->getOperand(OpIdx).setImm(CurrentPredicate);
-      BlockMask = ExpandBlockMask(BlockMask, CurrentPredicate);
+      BlockMask = expandPredBlockMask(BlockMask, CurrentPredicate);
 
       LLVM_DEBUG(dbgs() << "  adding : "; Iter->dump());
     }
index 843ba47..0266b50 100644 (file)
@@ -737,3 +737,34 @@ ARMVCC::VPTCodes llvm::getVPTInstrPredicate(const MachineInstr &MI,
   PredReg = MI.getOperand(PIdx+1).getReg();
   return (ARMVCC::VPTCodes)MI.getOperand(PIdx).getImm();
 }
+
+void llvm::recomputeVPTBlockMask(MachineInstr &Instr) {
+  assert(isVPTOpcode(Instr.getOpcode()) && "Not a VPST or VPT Instruction!");
+
+  MachineOperand &MaskOp = Instr.getOperand(0);
+  assert(MaskOp.isImm() && "Operand 0 is not the block mask of the VPT/VPST?!");
+
+  MachineBasicBlock::iterator Iter = ++Instr.getIterator(),
+                              End = Instr.getParent()->end();
+
+  // Verify that the instruction after the VPT/VPST is predicated (it should
+  // be), and skip it.
+  ARMVCC::VPTCodes Pred = getVPTInstrPredicate(*Iter);
+  assert(
+      Pred == ARMVCC::Then &&
+      "VPT/VPST should be followed by an instruction with a 'then' predicate!");
+  ++Iter;
+
+  // Iterate over the predicated instructions, updating the BlockMask as we go.
+  ARM::PredBlockMask BlockMask = ARM::PredBlockMask::T;
+  while (Iter != End) {
+    ARMVCC::VPTCodes Pred = getVPTInstrPredicate(*Iter);
+    if (Pred == ARMVCC::None)
+      break;
+    BlockMask = expandPredBlockMask(BlockMask, Pred);
+    ++Iter;
+  }
+
+  // Rewrite the BlockMask.
+  MaskOp.setImm((int64_t)(BlockMask));
+}
index b49a34f..ec37636 100644 (file)
@@ -78,6 +78,13 @@ inline ARMVCC::VPTCodes getVPTInstrPredicate(const MachineInstr &MI) {
   Register PredReg;
   return getVPTInstrPredicate(MI, PredReg);
 }
+
+// Recomputes the Block Mask of Instr, a VPT or VPST instruction.
+// This rebuilds the block mask of the instruction depending on the predicates
+// of the instructions following it. This should only be used after the
+// MVEVPTBlockInsertion pass has run, and should be used whenever a predicated
+// instruction is added to/removed from the block.
+void recomputeVPTBlockMask(MachineInstr &Instr);
 } // namespace llvm
 
 #endif
index 4ace61c..3356d56 100644 (file)
 
 using namespace llvm;
 namespace llvm {
+ARM::PredBlockMask expandPredBlockMask(ARM::PredBlockMask BlockMask,
+                                       ARMVCC::VPTCodes Kind) {
+  using PredBlockMask = ARM::PredBlockMask;
+  assert(Kind != ARMVCC::None && "Cannot expand a mask with None!");
+  assert(countTrailingZeros((unsigned)BlockMask) != 0 &&
+         "Mask is already full");
+
+  auto ChooseMask = [&](PredBlockMask AddedThen, PredBlockMask AddedElse) {
+    return Kind == ARMVCC::Then ? AddedThen : AddedElse;
+  };
+
+  switch (BlockMask) {
+  case PredBlockMask::T:
+    return ChooseMask(PredBlockMask::TT, PredBlockMask::TE);
+  case PredBlockMask::TT:
+    return ChooseMask(PredBlockMask::TTT, PredBlockMask::TTE);
+  case PredBlockMask::TE:
+    return ChooseMask(PredBlockMask::TET, PredBlockMask::TEE);
+  case PredBlockMask::TTT:
+    return ChooseMask(PredBlockMask::TTTT, PredBlockMask::TTTE);
+  case PredBlockMask::TTE:
+    return ChooseMask(PredBlockMask::TTET, PredBlockMask::TTEE);
+  case PredBlockMask::TET:
+    return ChooseMask(PredBlockMask::TETT, PredBlockMask::TETE);
+  case PredBlockMask::TEE:
+    return ChooseMask(PredBlockMask::TEET, PredBlockMask::TEEE);
+  default:
+    llvm_unreachable("Unknown Mask");
+  }
+}
+
 namespace ARMSysReg {
 
 // lookup system register using 12-bit SYSm value.
index b2e434f..80b7276 100644 (file)
@@ -121,6 +121,12 @@ namespace ARM {
   };
 } // namespace ARM
 
+// Expands a PredBlockMask by adding an E or a T at the end, depending on Kind.
+// e.g ExpandPredBlockMask(T, Then) = TT, ExpandPredBlockMask(TT, Else) = TTE,
+// and so on.
+ARM::PredBlockMask expandPredBlockMask(ARM::PredBlockMask BlockMask,
+                                       ARMVCC::VPTCodes Kind);
+
 inline static const char *ARMVPTPredToString(ARMVCC::VPTCodes CC) {
   switch (CC) {
   case ARMVCC::None:  return "none";