Started to implement unstructured branches in the compiler
authorBenjamin Segovia <devnull@localhost>
Tue, 1 May 2012 20:01:42 +0000 (20:01 +0000)
committerKeith Packard <keithp@keithp.com>
Fri, 10 Aug 2012 23:16:50 +0000 (16:16 -0700)
backend/src/backend/context.cpp
backend/src/backend/context.hpp
backend/src/ir/function.cpp
backend/src/ir/function.hpp
backend/src/ir/value.cpp

index a38b4c7..2c8df7d 100644 (file)
@@ -55,6 +55,7 @@ namespace gbe
     this->buildPatchList();
     this->buildArgList();
     this->buildUsedLabels();
+    this->buildJIPs();
     this->emitCode();
     return this->kernel;
   }
@@ -183,6 +184,100 @@ namespace gbe
     });
   }
 
+  // The idea is that foward branches can by-pass the target of previous
+  // forward branches. Since we run in SIMD mode, we must be sure that we are
+  // not skipping some computations. The idea is therefore to put JOIN points at
+  // the head of each block and to restrict the distance where to jump when
+  // taking a forward branch. We traverse the blocks top to bottom and use a
+  // O(n^2) stupid algorithm to track down which branches we can by-pass
+  void Context::buildJIPs(void) {
+    using namespace ir;
+
+    // Linearly store the branch target for each block and its own label
+    const LabelIndex noTarget(fn.labelNum());
+    vector<std::pair<LabelIndex, LabelIndex>> braTargets;
+    int32_t curr = 0, blockNum = fn.blockNum();
+    braTargets.resize(blockNum);
+
+    // If some blocks are unused we mark them as such by setting their own label
+    // as "invalid" (== noTarget)
+    for (auto &bb : braTargets) bb = std::make_pair(noTarget, noTarget);
+
+    fn.foreachBlock([&](const BasicBlock &bb) {
+      const LabelIndex ownLabel = bb.getLabelIndex();
+      const Instruction *last = bb.getLastInstruction();
+      if (last->getOpcode() != OP_BRA)
+        braTargets[curr++] = std::make_pair(ownLabel, noTarget);
+      else {
+        const BranchInstruction *bra = cast<BranchInstruction>(last);
+        braTargets[curr++] = std::make_pair(ownLabel, bra->getLabelIndex());
+      }
+    });
+
+    // For each block, we also figure out if the JOIN point (at the label
+    // instruction location) needs a branch to bypass useless computations
+    vector<LabelIndex> joinTargets;
+    joinTargets.resize(fn.labelNum());
+    for (auto &bb : joinTargets) bb = noTarget;
+
+    // We store here the labels bypassed by the current branch
+    vector<LabelIndex> bypassedLabels;
+    bypassedLabels.resize(blockNum);
+
+    // Now retraverse the blocks and figure out all JIPs
+    for (int32_t blockID = 0; blockID < blockNum; ++blockID) {
+      const LabelIndex ownLabel = braTargets[blockID].first;
+      const LabelIndex target = braTargets[blockID].second;
+      if (ownLabel == noTarget) continue; // unused block
+      if (target == noTarget) continue; // no branch at all
+      if (target <= ownLabel) continue; // bwd branch: nothing to do
+
+      // Traverse all previous blocks and see if we bypass their target
+      uint32_t bypassedNum = 0;
+      uint32_t JIP = target;
+      for (int32_t prevID = blockID-1; prevID >= 0; --prevID) {
+        const LabelIndex prevTarget = braTargets[prevID].second;
+        if (prevTarget == noTarget) continue; // no branch
+        if (prevTarget >= target) continue; // complete bypass
+        if (prevTarget <= ownLabel) continue; // branch falls before
+        bypassedLabels[bypassedNum++] = prevTarget;
+        JIP = min(uint32_t(JIP), uint32_t(prevTarget));
+      }
+
+      // We now have the (possibly) updated JIP for the branch
+      const BasicBlock &bb = fn.getBlock(ownLabel);
+      const Instruction *insn = bb.getLastInstruction();
+      GBE_ASSERT(insn->isMemberOf<BranchInstruction>() == true);
+      JIPs.insert(std::make_pair(insn, LabelIndex(JIP)));
+
+      // No bypassed targets
+      if (bypassedNum == 0) continue;
+
+      // When we have several bypassed targets, we must simply sort them and
+      // chain them such target_n points to target_{n+1}
+      bypassedLabels[bypassedNum++] = ownLabel;
+      std::sort(&bypassedLabels[0], &bypassedLabels[bypassedNum]);
+
+      // Bypassed labels have a JIP now. However, we will only insert the
+      // instructions later since *several* branches can bypass the same label.
+      // For that reason, we must consider the *minimum* JIP
+      for (uint32_t bypassedID = 0; bypassedID < bypassedNum-1; ++bypassedID) {
+        const LabelIndex curr = bypassedLabels[bypassedID];
+        const LabelIndex next = bypassedLabels[bypassedID+1];
+        joinTargets[curr] = min(joinTargets[curr], next);
+      }
+    }
+
+    // Now we also processed all JOIN points (i.e. each label). We can insert
+    // the label instructions that have a JIP
+    for (uint32_t label = 0; label < fn.labelNum(); ++label) {
+      const LabelIndex target = joinTargets[label];
+      if (target == noTarget) continue;
+      const Instruction *insn = fn.getLabelInstruction(LabelIndex(label));
+      JIPs.insert(std::make_pair(insn, target));
+    }
+  }
+
   bool Context::isScalarReg(const ir::Register &reg) const {
     GBE_ASSERT(fn.getProfile() == ir::Profile::PROFILE_OCL);
     if (fn.getInput(reg) != NULL) return true;
index a88fd1d..9c990b9 100644 (file)
@@ -27,6 +27,7 @@
 
 #include "sys/platform.hpp"
 #include "sys/set.hpp"
+#include "sys/map.hpp"
 #include "ir/instruction.hpp"
 #include <string>
 
@@ -35,7 +36,6 @@ namespace ir {
 
   class Unit;        // Contains the complete program
   class Function;    // We compile a function into a kernel
-  class Register;    // We compile a function into a kernel
   class Liveness;    // Describes liveness of each ir function register
   class FunctionDAG; // Describes the instruction dependencies
 
@@ -73,19 +73,26 @@ namespace gbe
     void buildArgList(void);
     /*! Build the sets of used labels */
     void buildUsedLabels(void);
+    /*! Build JIPs for each branch and possibly labels. Can be different from
+     *  the branch target due to unstructured branches
+     */
+    void buildJIPs(void);
     /*! Indicate if a register is scalar or not */
     bool isScalarReg(const ir::Register &reg) const;
     /*! Build the instruction stream */
     virtual void emitCode(void) = 0;
     /*! Allocate a new empty kernel */
     virtual Kernel *allocateKernel(void) = 0;
+    /*! Provide for each branch and label the label index target */
+    typedef map<const ir::Instruction*, ir::LabelIndex> JIPMap;
     const ir::Unit &unit;           //!< Unit that contains the kernel
     const ir::Function &fn;         //!< Function to compile
     std::string name;               //!< Name of the kernel to compile
     Kernel *kernel;                 //!< Kernel we are building
     ir::Liveness *liveness;         //!< Liveness info for the variables
-    ir::FunctionDAG *dag;           //!< Complete DAG of values on the function
-    set<ir::LabelIndex> usedLabels; //!< Set of all labels actually used
+    ir::FunctionDAG *dag;           //!< Graph of values on the function
+    set<ir::LabelIndex> usedLabels; //!< Set of all used labels
+    JIPMap JIPs;                    //!< Where to jump all labels / branches
     uint32_t simdWidth;             //!< Number of lanes per HW threads
   };
 
index e2a233f..bf27f6f 100644 (file)
@@ -74,6 +74,15 @@ namespace ir {
         *newBra = BRA(newIndex);
       newBra->replace(&insn);
     });
+
+    // Reset the label to block mapping
+    this->labels.resize(last);
+    foreachBlock([&](BasicBlock &bb) {
+      const Instruction *first = bb.getFirstInstruction();
+      const LabelInstruction *label = cast<LabelInstruction>(first);
+      const LabelIndex index = label->getLabelIndex();
+      this->labels[index] = &bb;
+    });
   }
 
   LabelIndex Function::newLabel(void) {
@@ -177,13 +186,12 @@ namespace ir {
       out << "decl_output %" << fn.getOutput(i) << std::endl;
     out << "## " << fn.blockNum() << " block"
         << plural(fn.blockNum()) << " ##" << std::endl;
-    for (uint32_t i = 0; i < fn.blockNum(); ++i) {
-      const BasicBlock &bb = fn.getBlock(i);
+    fn.foreachBlock([&](const BasicBlock &bb) {
       bb.foreach([&out] (const Instruction &insn) {
         out << insn << std::endl;
       });
       out << std::endl;
-    }
+    });
     out << ".end_function" << std::endl;
     return out;
   }
index 22d8e33..a2243d2 100644 (file)
@@ -205,11 +205,21 @@ namespace ir {
       GBE_ASSERT(ID < outputNum());
       return outputs[ID];
     }
-    /*! Get block ID */
-    INLINE const BasicBlock &getBlock(uint32_t ID) const {
-      GBE_ASSERT(ID < blockNum());
-      GBE_ASSERT(blocks[ID] != NULL);
-      return *blocks[ID];
+    /*! Get function the entry point block */
+    INLINE const BasicBlock &getTopBlock(void) const {
+     GBE_ASSERT(blockNum() > 0);
+      return *blocks[0];
+    }
+    /*! Get block from its label */
+    INLINE const BasicBlock &getBlock(LabelIndex label) const {
+      GBE_ASSERT(label < labelNum() && labels[label] != NULL);
+      return *labels[label];
+    }
+    /*! Get the label instruction from its label index */
+    INLINE const LabelInstruction *getLabelInstruction(LabelIndex index) const {
+      const BasicBlock *bb = this->labels[index];
+      const Instruction *first = bb->getFirstInstruction();
+      return cast<LabelInstruction>(first);
     }
     /*! Get the first index of the special registers and number of them */
     uint32_t getFirstSpecialReg(void) const;
index 2143e6b..1d40272 100644 (file)
@@ -165,7 +165,7 @@ namespace ir {
     const uint32_t inputNum = fn.inputNum();
 
     // The first block must also transfer the function arguments
-    const BasicBlock &top = fn.getBlock(0);
+    const BasicBlock &top = fn.getTopBlock();
     const Liveness::BlockInfo &info = this->liveness.getBlockInfo(&top);
     GBE_ASSERT(defMap.contains(&top) == true);
     auto blockDefMap = defMap.find(&top)->second;