ir3/ra: Add proper support for multiple destinations
authorConnor Abbott <cwabbott0@gmail.com>
Thu, 2 Dec 2021 18:41:40 +0000 (19:41 +0100)
committerMarge Bot <emma+marge@anholt.net>
Thu, 10 Mar 2022 17:15:29 +0000 (17:15 +0000)
We weren't considering the other destinations when allocating a
destination, so we could allocate overlapping destinations. This wasn't
done before because we never had a need for it, but the subgroup
reduction macros will need it.

The trickiest part of this is that we have to rewrite the
compress_regs_left fallback, because we may have to move around the
other already-allocated destinations. We now have a list of destinations
to (re)allocate in addition to the popped live intervals. For the rest
of the destination handling, we can just bail out if the proposed spot
for something overlaps another destination, but for the fallback we have
to handle all the cases gracefully. I also added support for odd
combinations of multiple destinations where some of them are tied, which
we'll use in the next commit to handle early-clobber destinations and
which will actually be used because one of the destinations of the
subgroup reduction macro will be early-clobber. The result is that the
order of intervals to allocate is now a lot more complicated.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14107>

src/freedreno/ir3/ir3_ra.c

index 54e3fad..988fa55 100644 (file)
@@ -734,9 +734,52 @@ ra_move_interval(struct ra_ctx *ctx, struct ra_file *file,
    ra_push_interval(ctx, file, &temp, dst);
 }
 
+static struct ra_file *
+ra_get_file(struct ra_ctx *ctx, struct ir3_register *reg)
+{
+   if (reg->flags & IR3_REG_SHARED)
+      return &ctx->shared;
+   else if (ctx->merged_regs || !(reg->flags & IR3_REG_HALF))
+      return &ctx->full;
+   else
+      return &ctx->half;
+}
+
+
+/* Returns true if the proposed spot for "dst" or a killed source overlaps a
+ * destination that's been allocated.
+ */
 static bool
-get_reg_specified(struct ra_file *file, struct ir3_register *reg,
-                  physreg_t physreg, bool is_source)
+check_dst_overlap(struct ra_ctx *ctx, struct ra_file *file,
+                  struct ir3_register *dst, physreg_t start,
+                  physreg_t end)
+{
+   struct ir3_instruction *instr = dst->instr;
+
+   ra_foreach_dst (other_dst, instr) {
+      /* We assume only destinations before the current one have been allocated.
+       */
+      if (other_dst == dst)
+         break;
+
+      if (ra_get_file(ctx, other_dst) != file)
+         continue;
+
+      struct ra_interval *other_interval = &ctx->intervals[other_dst->name];
+      assert(!other_interval->interval.parent);
+      physreg_t other_start = other_interval->physreg_start;
+      physreg_t other_end = other_interval->physreg_end;
+
+      if (other_end > start && end > other_start)
+         return true;
+   }
+
+   return false;
+}
+
+static bool
+get_reg_specified(struct ra_ctx *ctx, struct ra_file *file,
+                  struct ir3_register *reg, physreg_t physreg, bool is_source)
 {
    for (unsigned i = 0; i < reg_size(reg); i++) {
       if (!BITSET_TEST(is_source ? file->available_to_evict : file->available,
@@ -744,6 +787,9 @@ get_reg_specified(struct ra_file *file, struct ir3_register *reg,
          return false;
    }
 
+   if (check_dst_overlap(ctx, file, reg, physreg, physreg + reg_size(reg)))
+      return false;
+
    return true;
 }
 
@@ -799,7 +845,11 @@ try_evict_regs(struct ra_ctx *ctx, struct ra_file *file,
             size--;
          }
 
-         if (size >= conflicting->physreg_end - conflicting->physreg_start) {
+         unsigned conflicting_size =
+            conflicting->physreg_end - conflicting->physreg_start;
+         if (size >= conflicting_size &&
+             !check_dst_overlap(ctx, file, reg, avail_start, avail_start +
+                                conflicting_size)) {
             for (unsigned i = 0;
                  i < conflicting->physreg_end - conflicting->physreg_start; i++)
                BITSET_CLEAR(available_to_evict, avail_start + i);
@@ -845,6 +895,10 @@ try_evict_regs(struct ra_ctx *ctx, struct ra_file *file,
          if (!killed_available)
             continue;
 
+         if (check_dst_overlap(ctx, file, reg, killed->physreg_start,
+                               killed->physreg_end))
+            continue;
+
          /* Check for alignment if one is a full reg */
          if ((!(killed->interval.reg->flags & IR3_REG_HALF) ||
               !(conflicting->interval.reg->flags & IR3_REG_HALF)) &&
@@ -890,15 +944,16 @@ removed_interval_cmp(const void *_i1, const void *_i2)
 
    /* We sort the registers as follows:
     *
-    * |--------------------------------------------------------------------|
-    * |                    |             |             |                   |
-    * |  Half live-through | Half killed | Full killed | Full live-through |
-    * |                    |             |             |                   |
-    * |--------------------------------------------------------------------|
-    *                        |                 |
-    *                        |   Destination   |
-    *                        |                 |
-    *                        |-----------------|
+    * |--------------------------------------------------------------------------------------|
+    * |               |                  |        |        |                  |              |
+    * |  Half         | Half             | Half   | Full   | Full             | Full         |
+    * |  live-through | tied destination | killed | killed | tied destination | live-through |
+    * |               |                  |        |        |                  |              |
+    * |--------------------------------------------------------------------------------------|
+    *                                    |                 |
+    *                                    |   Destination   |
+    *                                    |                 |
+    *                                    |-----------------|
     *
     * Half-registers have to be first so that they stay in the low half of
     * the register file. Then half and full killed must stay together so that
@@ -930,6 +985,37 @@ removed_interval_cmp(const void *_i1, const void *_i2)
    return 0;
 }
 
+static int
+dsts_cmp(const void *_i1, const void *_i2)
+{
+   struct ir3_register *i1 = *(struct ir3_register *const *) _i1;
+   struct ir3_register *i2 = *(struct ir3_register *const *) _i2;
+
+   /* Treat tied destinations as-if they are live-through sources, and normal
+    * destinations as killed sources.
+    */
+   unsigned i1_align = reg_elem_size(i1);
+   unsigned i2_align = reg_elem_size(i2);
+   if (i1_align > i2_align)
+      return 1;
+   if (i1_align < i2_align)
+      return -1;
+
+   if (i1_align == 1) {
+      if (!i2->tied)
+         return -1;
+      if (!i1->tied)
+         return 1;
+   } else {
+      if (!i2->tied)
+         return 1;
+      if (!i1->tied)
+         return -1;
+   }
+
+   return 0;
+}
+
 /* "Compress" all the live intervals so that there is enough space for the
  * destination register. As there can be gaps when a more-aligned interval
  * follows a less-aligned interval, this also sorts them to remove such
@@ -940,29 +1026,100 @@ removed_interval_cmp(const void *_i1, const void *_i2)
  * Return the physreg to use.
  */
 static physreg_t
-compress_regs_left(struct ra_ctx *ctx, struct ra_file *file, unsigned size,
-                   unsigned align, bool is_source)
+compress_regs_left(struct ra_ctx *ctx, struct ra_file *file,
+                   struct ir3_register *reg)
 {
+   unsigned align = reg_elem_size(reg);
    DECLARE_ARRAY(struct ra_removed_interval, intervals);
    intervals_count = intervals_sz = 0;
    intervals = NULL;
 
+   DECLARE_ARRAY(struct ir3_register *, dsts);
+   dsts_count = dsts_sz = 0;
+   dsts = NULL;
+   array_insert(ctx, dsts, reg);
+   bool dst_inserted[reg->instr->dsts_count];
+
+   unsigned dst_size = reg->tied ? 0 : reg_size(reg);
+   unsigned tied_dst_size = reg->tied ? reg_size(reg) : 0;
+   unsigned half_dst_size = 0, tied_half_dst_size = 0;
+   if (align == 1) {
+      half_dst_size = dst_size;
+      tied_half_dst_size = tied_dst_size;
+   }
+
    unsigned removed_size = 0, removed_half_size = 0;
+   unsigned removed_killed_size = 0, removed_killed_half_size = 0;
    unsigned file_size =
       align == 1 ? MIN2(file->size, RA_HALF_SIZE) : file->size;
    physreg_t start_reg = 0;
 
    foreach_interval_rev_safe (interval, file) {
+      /* We'll check if we can compact the intervals starting here. */
+      physreg_t candidate_start = interval->physreg_end;
+
+      /* Check if there are any other destinations we need to compact. */
+      ra_foreach_dst_n (other_dst, n, reg->instr) {
+         if (other_dst == reg)
+            break;
+         if (ra_get_file(ctx, other_dst) != file)
+            continue;
+         if (dst_inserted[n])
+            continue;
+
+         struct ra_interval *other_interval = &ctx->intervals[other_dst->name];
+         /* if the destination partially overlaps this interval, we need to
+          * extend candidate_start to the end.
+          */
+         if (other_interval->physreg_start < candidate_start) {
+            candidate_start = MAX2(candidate_start,
+                                   other_interval->physreg_end);
+            continue;
+         }
+
+         dst_inserted[n] = true;
+
+         /* dst intervals with a tied killed source are considered attached to
+          * that source. Don't actually insert them. This means we have to
+          * update them below if their tied source moves.
+          */
+         if (other_dst->tied) {
+            struct ra_interval *tied_interval =
+               &ctx->intervals[other_dst->tied->def->name];
+            if (tied_interval->is_killed)
+               continue;
+         }
+
+         d("popping destination %u physreg %u\n",
+           other_interval->interval.reg->name,
+           other_interval->physreg_start);
+
+         array_insert(ctx, dsts, other_dst);
+         unsigned interval_size = reg_size(other_dst);
+         if (other_dst->tied) {
+            tied_dst_size += interval_size;
+            if (other_interval->interval.reg->flags & IR3_REG_HALF)
+               tied_half_dst_size += interval_size;
+         } else {
+            dst_size += interval_size;
+            if (other_interval->interval.reg->flags & IR3_REG_HALF)
+               half_dst_size += interval_size;
+         }
+      }
+
       /* Check if we can sort the intervals *after* this one and have enough
-       * space leftover to accomodate "size" units. Also check that we have
-       * enough space leftover for half-registers, if we're inserting a
-       * half-register (otherwise we only shift any half-registers down so they
-       * should be safe).
+       * space leftover to accomodate all intervals, keeping in mind that killed
+       * sources overlap non-tied destinations. Also check that we have enough
+       * space leftover for half-registers, if we're inserting a half-register
+       * (otherwise we only shift any half-registers down so they should be
+       * safe).
        */
-      if (interval->physreg_end + size + removed_size <= file->size &&
+      if (candidate_start + removed_size + tied_dst_size +
+          MAX2(removed_killed_size, dst_size) <= file->size &&
           (align != 1 ||
-           interval->physreg_end + size + removed_half_size <= file_size)) {
-         start_reg = interval->physreg_end;
+           candidate_start + removed_half_size + tied_half_dst_size +
+           MAX2(removed_killed_half_size, half_dst_size) <= file_size)) {
+         start_reg = candidate_start;
          break;
       }
 
@@ -971,20 +1128,23 @@ compress_regs_left(struct ra_ctx *ctx, struct ra_file *file, unsigned size,
        */
       assert(!interval->frozen);
 
-      /* Killed sources don't count because they go at the end and can
-       * overlap the register we're trying to add, unless it's a source.
+      /* Killed sources are different because they go at the end and can
+       * overlap the register we're trying to add.
        */
-      if (!interval->is_killed || is_source) {
-         removed_size += interval->physreg_end - interval->physreg_start;
-         if (interval->interval.reg->flags & IR3_REG_HALF) {
-            removed_half_size += interval->physreg_end -
-               interval->physreg_start;
-         }
+      unsigned interval_size = interval->physreg_end - interval->physreg_start;
+      if (interval->is_killed) {
+         removed_killed_size += interval_size;
+         if (interval->interval.reg->flags & IR3_REG_HALF)
+            removed_killed_half_size += interval_size;
+      } else {
+         removed_size += interval_size;
+         if (interval->interval.reg->flags & IR3_REG_HALF)
+            removed_half_size += interval_size;
       }
 
       /* Now that we've done the accounting, pop this off */
-      d("popping interval %u physreg %u\n", interval->interval.reg->name,
-        interval->physreg_start);
+      d("popping interval %u physreg %u%s\n", interval->interval.reg->name,
+        interval->physreg_start, interval->is_killed ? ", killed" : "");
       array_insert(ctx, intervals, ra_pop_interval(ctx, file, interval));
    }
 
@@ -993,48 +1153,130 @@ compress_regs_left(struct ra_ctx *ctx, struct ra_file *file, unsigned size,
     */
 
    qsort(intervals, intervals_count, sizeof(*intervals), removed_interval_cmp);
+   qsort(dsts, dsts_count, sizeof(*dsts), dsts_cmp);
 
-   physreg_t physreg = start_reg;
+   physreg_t live_reg = start_reg;
+   physreg_t dst_reg = (physreg_t)~0;
    physreg_t ret_reg = (physreg_t)~0;
-   for (unsigned i = 0; i < intervals_count; i++) {
-      if (ret_reg == (physreg_t)~0 &&
-          ((intervals[i].interval->is_killed && !is_source) ||
-           !(intervals[i].interval->interval.reg->flags & IR3_REG_HALF))) {
-         ret_reg = ALIGN(physreg, align);
+   unsigned dst_index = 0;
+   unsigned live_index = 0;
+
+   /* We have two lists of intervals to process, live intervals and destination
+    * intervals. Process them in the order of the disgram in insert_cmp().
+    */
+   while (live_index < intervals_count || dst_index < dsts_count) {
+      bool process_dst;
+      if (live_index == intervals_count) {
+         process_dst = true;
+      } else if (dst_index == dsts_count) {
+         process_dst = false;
+      } else {
+         struct ir3_register *dst = dsts[dst_index];
+         struct ra_interval *live_interval = intervals[live_index].interval;
+
+         bool live_half = live_interval->interval.reg->flags & IR3_REG_HALF;
+         bool live_killed = live_interval->is_killed;
+         bool dst_half = dst->flags & IR3_REG_HALF;
+         bool dst_tied = dst->tied;
+
+         if (live_half && !live_killed) {
+            /* far-left of diagram. */
+            process_dst = false;
+         } else if (dst_half && dst_tied) {
+            /* mid-left of diagram. */
+            process_dst = true;
+         } else if (!dst_tied) {
+            /* bottom of disagram. */
+            process_dst = true;
+         } else if (live_killed) {
+            /* middle of diagram. */
+            process_dst = false;
+         } else if (!dst_half && dst_tied) {
+            /* mid-right of diagram. */
+            process_dst = true;
+         } else {
+            /* far right of diagram. */
+            assert(!live_killed && !live_half);
+            process_dst = false;
+         }
       }
 
-      if (ret_reg != (physreg_t)~0 &&
-          (is_source || !intervals[i].interval->is_killed)) {
-         physreg = MAX2(physreg, ret_reg + size);
+      struct ir3_register *cur_reg =
+         process_dst ? dsts[dst_index] :
+         intervals[live_index].interval->interval.reg;
+
+      physreg_t physreg;
+      if (process_dst && !cur_reg->tied) {
+         if (dst_reg == (physreg_t)~0)
+            dst_reg = live_reg;
+         physreg = dst_reg;
+      } else {
+         physreg = live_reg;
+         struct ra_interval *live_interval = intervals[live_index].interval;
+         bool live_killed = live_interval->is_killed;
+         /* If this is live-through and we've processed the destinations, we
+          * need to make sure we take into account any overlapping destinations.
+          */
+         if (!live_killed && dst_reg != (physreg_t)~0)
+            physreg = MAX2(physreg, dst_reg);
       }
 
-      if (!(intervals[i].interval->interval.reg->flags & IR3_REG_HALF)) {
+      if (!(cur_reg->flags & IR3_REG_HALF))
          physreg = ALIGN(physreg, 2);
-      }
 
-      if (physreg + intervals[i].size >
-          reg_file_size(file, intervals[i].interval->interval.reg)) {
+      d("pushing reg %u physreg %u\n", cur_reg->name, physreg);
+
+      unsigned interval_size = reg_size(cur_reg);
+      if (physreg + interval_size >
+          reg_file_size(file, cur_reg)) {
          d("ran out of room for interval %u!\n",
-           intervals[i].interval->interval.reg->name);
+           cur_reg->name);
          unreachable("reg pressure calculation was wrong!");
          return 0;
       }
 
-      d("pushing interval %u physreg %u\n",
-        intervals[i].interval->interval.reg->name, physreg);
-      ra_push_interval(ctx, file, &intervals[i], physreg);
+      if (process_dst) {
+         if (cur_reg == reg) {
+            ret_reg = physreg;
+         } else {
+            struct ra_interval *interval = &ctx->intervals[cur_reg->name];
+            interval->physreg_start = physreg;
+            interval->physreg_end = physreg + interval_size;
+         }
+         dst_index++;
+      } else {
+         ra_push_interval(ctx, file, &intervals[live_index], physreg);
+         live_index++;
+      }
+
+      physreg += interval_size;
 
-      physreg += intervals[i].size;
+      if (process_dst && !cur_reg->tied) {
+         dst_reg = physreg;
+      } else {
+         live_reg = physreg;
+      }
    }
 
-   if (ret_reg == (physreg_t)~0)
-      ret_reg = physreg;
+   /* If we shuffled around a tied source that is killed, we may have to update
+    * its corresponding destination since we didn't insert it above.
+    */
+   ra_foreach_dst (dst, reg->instr) {
+      if (dst == reg)
+         break;
 
-   ret_reg = ALIGN(ret_reg, align);
-   if (ret_reg + size > file_size) {
-      d("ran out of room for the new interval!\n");
-      unreachable("reg pressure calculation was wrong!");
-      return 0;
+      struct ir3_register *tied = dst->tied;
+      if (!tied)
+         continue;
+
+      struct ra_interval *tied_interval = &ctx->intervals[tied->def->name];
+      if (!tied_interval->is_killed)
+         continue;
+
+      struct ra_interval *dst_interval = &ctx->intervals[dst->name];
+      unsigned dst_size = reg_size(dst);
+      dst_interval->physreg_start = ra_interval_get_physreg(tied_interval);
+      dst_interval->physreg_end = dst_interval->physreg_start + dst_size;
    }
 
    return ret_reg;
@@ -1060,7 +1302,8 @@ update_affinity(struct ra_file *file, struct ir3_register *reg,
  * a round-robin algorithm to reduce false dependencies.
  */
 static physreg_t
-find_best_gap(struct ra_file *file, unsigned file_size, unsigned size,
+find_best_gap(struct ra_ctx *ctx, struct ra_file *file,
+              struct ir3_register *dst, unsigned file_size, unsigned size,
               unsigned align, bool is_source)
 {
    /* This can happen if we create a very large merge set. Just bail out in that
@@ -1084,6 +1327,11 @@ find_best_gap(struct ra_file *file, unsigned file_size, unsigned size,
       }
 
       if (is_available) {
+         is_available =
+            !check_dst_overlap(ctx, file, dst, candidate, candidate + size);
+      }
+
+      if (is_available) {
          file->start = (candidate + size) % file_size;
          return candidate;
       }
@@ -1096,17 +1344,6 @@ find_best_gap(struct ra_file *file, unsigned file_size, unsigned size,
    return (physreg_t)~0;
 }
 
-static struct ra_file *
-ra_get_file(struct ra_ctx *ctx, struct ir3_register *reg)
-{
-   if (reg->flags & IR3_REG_SHARED)
-      return &ctx->shared;
-   else if (ctx->merged_regs || !(reg->flags & IR3_REG_HALF))
-      return &ctx->full;
-   else
-      return &ctx->half;
-}
-
 /* This is the main entrypoint for picking a register. Pick a free register
  * for "reg", shuffling around sources if necessary. In the normal case where
  * "is_source" is false, this register can overlap with killed sources
@@ -1126,7 +1363,7 @@ get_reg(struct ra_ctx *ctx, struct ra_file *file, struct ir3_register *reg,
          reg->merge_set->preferred_reg + reg->merge_set_offset;
       if (preferred_reg < file_size &&
           preferred_reg % reg_elem_size(reg) == 0 &&
-          get_reg_specified(file, reg, preferred_reg, is_source))
+          get_reg_specified(ctx, file, reg, preferred_reg, is_source))
          return preferred_reg;
    }
 
@@ -1137,7 +1374,8 @@ get_reg(struct ra_ctx *ctx, struct ra_file *file, struct ir3_register *reg,
    unsigned size = reg_size(reg);
    if (reg->merge_set && reg->merge_set->preferred_reg == (physreg_t)~0 &&
        size < reg->merge_set->size) {
-      physreg_t best_reg = find_best_gap(file, file_size, reg->merge_set->size,
+      physreg_t best_reg = find_best_gap(ctx, file, reg, file_size,
+                                         reg->merge_set->size,
                                          reg->merge_set->alignment, is_source);
       if (best_reg != (physreg_t)~0u) {
          best_reg += reg->merge_set_offset;
@@ -1160,14 +1398,14 @@ get_reg(struct ra_ctx *ctx, struct ra_file *file, struct ir3_register *reg,
             physreg_t src_physreg = ra_interval_get_physreg(src_interval);
             if (src_physreg % reg_elem_size(reg) == 0 &&
                 src_physreg + size <= file_size &&
-                get_reg_specified(file, reg, src_physreg, is_source))
+                get_reg_specified(ctx, file, reg, src_physreg, is_source))
                return src_physreg;
          }
       }
    }
 
    physreg_t best_reg =
-      find_best_gap(file, file_size, size, reg_elem_size(reg), is_source);
+      find_best_gap(ctx, file, reg, file_size, size, reg_elem_size(reg), is_source);
    if (best_reg != (physreg_t)~0u) {
       return best_reg;
    }
@@ -1195,8 +1433,7 @@ get_reg(struct ra_ctx *ctx, struct ra_file *file, struct ir3_register *reg,
    }
 
    /* Use the dumb fallback only if try_evict_regs() fails. */
-   return compress_regs_left(ctx, file, reg_size(reg), reg_elem_size(reg),
-                             is_source);
+   return compress_regs_left(ctx, file, reg);
 }
 
 static void
@@ -1625,7 +1862,7 @@ handle_precolored_source(struct ra_ctx *ctx, struct ir3_register *src)
     * anything unless it overlaps with our precolored physreg, so we don't
     * have to worry about evicting other precolored sources.
     */
-   if (!get_reg_specified(file, src, physreg, true)) {
+   if (!get_reg_specified(ctx, file, src, physreg, true)) {
       unsigned eviction_count;
       if (!try_evict_regs(ctx, file, src, physreg, &eviction_count, true,
                           false)) {