[EarlyCSE] Do not CSE convergent calls in different basic blocks
authorJay Foad <jay.foad@amd.com>
Thu, 27 Apr 2023 12:43:11 +0000 (13:43 +0100)
committerJay Foad <jay.foad@amd.com>
Fri, 28 Apr 2023 13:50:48 +0000 (14:50 +0100)
"convergent" is documented as meaning that the call cannot be made
control-dependent on more values, but in practice we also require that
it cannot be made control-dependent on fewer values, e.g. it cannot be
hoisted out of the body of an "if" statement.

In code like this, if we allow CSE to combine the two calls:

  x = convergent_call();
  if (cond) {
    y = convergent_call();
    use y;
  }

then we get this:

  x = convergent_call();
  if (cond) {
    use x;
  }

This is conceptually equivalent to moving the second call out of the
body of the "if", up to the location of the first call, so it should be
disallowed.

Differential Revision: https://reviews.llvm.org/D149348

llvm/lib/Transforms/Scalar/EarlyCSE.cpp
llvm/test/CodeGen/AMDGPU/cse-convergent.ll

index 9c745a2..09fecc1 100644 (file)
@@ -318,6 +318,14 @@ static unsigned getHashValueImpl(SimpleValue Val) {
     return hash_combine(GCR->getOpcode(), GCR->getOperand(0),
                         GCR->getBasePtr(), GCR->getDerivedPtr());
 
+  // Don't CSE convergent calls in different basic blocks, because they
+  // implicitly depend on the set of threads that is currently executing.
+  if (CallInst *CI = dyn_cast<CallInst>(Inst); CI && CI->isConvergent()) {
+    return hash_combine(
+        Inst->getOpcode(), Inst->getParent(),
+        hash_combine_range(Inst->value_op_begin(), Inst->value_op_end()));
+  }
+
   // Mix in the opcode.
   return hash_combine(
       Inst->getOpcode(),
@@ -344,8 +352,16 @@ static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) {
 
   if (LHSI->getOpcode() != RHSI->getOpcode())
     return false;
-  if (LHSI->isIdenticalToWhenDefined(RHSI))
+  if (LHSI->isIdenticalToWhenDefined(RHSI)) {
+    // Convergent calls implicitly depend on the set of threads that is
+    // currently executing, so conservatively return false if they are in
+    // different basic blocks.
+    if (CallInst *CI = dyn_cast<CallInst>(LHSI);
+        CI && CI->isConvergent() && LHSI->getParent() != RHSI->getParent())
+      return false;
+
     return true;
+  }
 
   // If we're not strictly identical, we still might be a commutable instruction
   if (BinaryOperator *LHSBinOp = dyn_cast<BinaryOperator>(LHSI)) {
index 806fb06..ed59076 100644 (file)
@@ -21,15 +21,25 @@ define i32 @test(i32 %val, i32 %cond) {
 ; GCN-NEXT:    s_or_saveexec_b32 s4, -1
 ; GCN-NEXT:    v_mov_b32_dpp v2, v3 row_xmask:1 row_mask:0xf bank_mask:0xf
 ; GCN-NEXT:    s_mov_b32 exec_lo, s4
-; GCN-NEXT:    v_mov_b32_e32 v4, 0
-; GCN-NEXT:    v_mov_b32_e32 v0, v2
+; GCN-NEXT:    v_mov_b32_e32 v5, 0
+; GCN-NEXT:    v_mov_b32_e32 v4, v2
 ; GCN-NEXT:    v_cmp_eq_u32_e32 vcc_lo, 0, v1
 ; GCN-NEXT:    s_and_saveexec_b32 s4, vcc_lo
 ; GCN-NEXT:  ; %bb.1: ; %if
-; GCN-NEXT:    v_mov_b32_e32 v4, v0
+; GCN-NEXT:    s_or_saveexec_b32 s5, -1
+; GCN-NEXT:    v_mov_b32_e32 v2, 0
+; GCN-NEXT:    s_mov_b32 exec_lo, s5
+; GCN-NEXT:    v_mov_b32_e32 v3, v0
+; GCN-NEXT:    s_not_b32 exec_lo, exec_lo
+; GCN-NEXT:    v_mov_b32_e32 v3, 0
+; GCN-NEXT:    s_not_b32 exec_lo, exec_lo
+; GCN-NEXT:    s_or_saveexec_b32 s5, -1
+; GCN-NEXT:    v_mov_b32_dpp v2, v3 row_xmask:1 row_mask:0xf bank_mask:0xf
+; GCN-NEXT:    s_mov_b32 exec_lo, s5
+; GCN-NEXT:    v_mov_b32_e32 v5, v2
 ; GCN-NEXT:  ; %bb.2: ; %end
 ; GCN-NEXT:    s_or_b32 exec_lo, exec_lo, s4
-; GCN-NEXT:    v_add_nc_u32_e32 v0, v0, v4
+; GCN-NEXT:    v_add_nc_u32_e32 v0, v4, v5
 ; GCN-NEXT:    s_xor_saveexec_b32 s4, -1
 ; GCN-NEXT:    s_clause 0x1
 ; GCN-NEXT:    buffer_load_dword v2, off, s[0:3], s32