aco: try sign-extending or shifting constants in propagate_constants_vop3p
authorRhys Perry <pendingchaos02@gmail.com>
Mon, 2 May 2022 13:07:03 +0000 (14:07 +0100)
committerMarge Bot <emma+marge@anholt.net>
Tue, 5 Jul 2022 16:39:56 +0000 (16:39 +0000)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/16296>

src/amd/compiler/aco_optimizer.cpp

index 82dbf82..4e8a701 100644 (file)
@@ -922,40 +922,64 @@ propagate_constants_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& i
 
    /* try to fold inline constants */
    VOP3P_instruction* vop3p = &instr->vop3p();
-   /* TODO: if bits==32, we might be able to get an inline constant if we sign-extend or shift left
-    * 16 bits.
-    */
-   Operand const_lo = Operand::get_const(ctx.program->gfx_level, info.val & 0xffff, bits / 8u);
-   Operand const_hi = Operand::get_const(ctx.program->gfx_level, info.val >> 16, bits / 8u);
    bool opsel_lo = (vop3p->opsel_lo >> i) & 1;
    bool opsel_hi = (vop3p->opsel_hi >> i) & 1;
 
-   if (const_hi.isLiteral() && (opsel_lo || opsel_hi))
-      return;
-   if (const_lo.isLiteral() && !(opsel_lo && opsel_hi))
-      return;
+   Operand const_op[2];
+   bool const_opsel[2] = {false, false};
+   for (unsigned j = 0; j < 2; j++) {
+      if ((unsigned)opsel_lo != j && (unsigned)opsel_hi != j)
+         continue; /* this half is unused */
+
+      uint16_t val = info.val >> (j ? 16 : 0);
+      Operand op = Operand::get_const(ctx.program->gfx_level, val, bits / 8u);
+      if (bits == 32 && op.isLiteral()) /* try sign extension */
+         op = Operand::get_const(ctx.program->gfx_level, val | 0xffff0000, 4);
+      if (bits == 32 && op.isLiteral()) { /* try shifting left */
+         op = Operand::get_const(ctx.program->gfx_level, val << 16, 4);
+         const_opsel[j] = true;
+      }
+      if (op.isLiteral())
+         return;
+      const_op[j] = op;
+   }
+
+   Operand const_lo = const_op[0];
+   Operand const_hi = const_op[1];
+   bool const_lo_opsel = const_opsel[0];
+   bool const_hi_opsel = const_opsel[1];
 
    if (opsel_lo == opsel_hi) {
       /* use the single 16bit value */
       instr->operands[i] = opsel_lo ? const_hi : const_lo;
 
-      /* opsel must point to lo for both halves */
-      vop3p->opsel_lo &= ~(1 << i);
-      vop3p->opsel_hi &= ~(1 << i);
+      /* opsel must point the same for both halves */
+      opsel_lo = opsel_lo ? const_hi_opsel : const_lo_opsel;
+      opsel_hi = opsel_lo;
    } else if (const_lo == const_hi) {
       /* both constants are the same */
       instr->operands[i] = const_lo;
 
-      /* opsel must point to lo for both halves */
-      vop3p->opsel_lo &= ~(1 << i);
-      vop3p->opsel_hi &= ~(1 << i);
-   } else if (const_lo.constantValue() == const_hi.constantValue16(true)) {
+      /* opsel must point the same for both halves */
+      opsel_lo = const_lo_opsel;
+      opsel_hi = const_lo_opsel;
+   } else if (const_lo.constantValue16(const_lo_opsel) ==
+              const_hi.constantValue16(!const_hi_opsel)) {
       instr->operands[i] = const_hi;
 
       /* redirect opsel selection */
-      vop3p->opsel_lo ^= (1 << i);
-      vop3p->opsel_hi ^= (1 << i);
+      opsel_lo = opsel_lo ? const_hi_opsel : !const_hi_opsel;
+      opsel_hi = opsel_hi ? const_hi_opsel : !const_hi_opsel;
+   } else if (const_hi.constantValue16(const_hi_opsel) ==
+              const_lo.constantValue16(!const_lo_opsel)) {
+      instr->operands[i] = const_lo;
+
+      /* redirect opsel selection */
+      opsel_lo = opsel_lo ? !const_lo_opsel : const_lo_opsel;
+      opsel_hi = opsel_hi ? !const_lo_opsel : const_lo_opsel;
    } else if (bits == 16 && const_lo.constantValue() == (const_hi.constantValue() ^ (1 << 15))) {
+      assert(const_lo_opsel == false && const_hi_opsel == false);
+
       /* const_lo == -const_hi */
       if (!instr_info.can_use_input_modifiers[(int)instr->opcode])
          return;
@@ -966,9 +990,12 @@ propagate_constants_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& i
       vop3p->neg_hi[i] ^= opsel_hi ^ neg_lo;
 
       /* opsel must point to lo for both operands */
-      vop3p->opsel_lo &= ~(1 << i);
-      vop3p->opsel_hi &= ~(1 << i);
+      opsel_lo = false;
+      opsel_hi = false;
    }
+
+   vop3p->opsel_lo = opsel_lo ? (vop3p->opsel_lo | (1 << i)) : (vop3p->opsel_lo & ~(1 << i));
+   vop3p->opsel_hi = opsel_hi ? (vop3p->opsel_hi | (1 << i)) : (vop3p->opsel_hi & ~(1 << i));
 }
 
 bool