JIT: Insert readbacks eagerly in physical promotion (#87809)
authorJakob Botsch Nielsen <Jakob.botsch.nielsen@gmail.com>
Tue, 20 Jun 2023 17:43:57 +0000 (19:43 +0200)
committerGitHub <noreply@github.com>
Tue, 20 Jun 2023 17:43:57 +0000 (19:43 +0200)
Physical promotion currently inserts readbacks (reading the struct local
back into the field local) as COMMAs right before a local that needs it.
This is not right for QMARKs where it may mean we only end up reading
back the local in one of the branches.

This change makes physical promotion insert readbacks as new statements
before any statement that is going to need it. While we could do this
for QMARKs only, it is done for any statement indiscriminately since it
has two benefits:
1. It allows forward-sub to kick in for the readbacks, which can lead to
   a contained LCL_FLD
2. It stops us from disabling local copy prop by avoiding the creation
   of embedded stores.

The existing logic is still necessary to keep in case the readback was
marked within the same tree.

Fix #87508

src/coreclr/jit/promotion.cpp
src/coreclr/jit/promotion.h
src/coreclr/jit/promotiondecomposition.cpp
src/tests/JIT/Directed/physicalpromotion/readbackbeforeqmark.cs [new file with mode: 0644]
src/tests/JIT/Directed/physicalpromotion/readbackbeforeqmark.csproj [new file with mode: 0644]

index 7bc514d..e376c64 100644 (file)
@@ -1775,6 +1775,8 @@ void ReplaceVisitor::StartBlock(BasicBlock* block)
             assert(rep.NeedsWriteBack);
         }
     }
+
+    assert(m_numPendingReadBacks == 0);
 #endif
 
     // OSR locals and parameters may need an initial read back, which we mark
@@ -1802,11 +1804,11 @@ void ReplaceVisitor::StartBlock(BasicBlock* block)
 
         for (size_t i = 0; i < agg->Replacements.size(); i++)
         {
-            Replacement& rep   = agg->Replacements[i];
-            rep.NeedsWriteBack = false;
+            Replacement& rep = agg->Replacements[i];
+            ClearNeedsWriteBack(rep);
             if (m_liveness->IsReplacementLiveIn(block, agg->LclNum, (unsigned)i))
             {
-                rep.NeedsReadBack = true;
+                SetNeedsReadBack(rep);
                 JITDUMP("  V%02u (%s) marked\n", rep.LclNum, rep.Description);
             }
             else
@@ -1879,14 +1881,75 @@ void ReplaceVisitor::EndBlock()
                             m_currentBlock->bbNum);
                 }
 
-                rep.NeedsReadBack = false;
+                ClearNeedsReadBack(rep);
             }
 
-            rep.NeedsWriteBack = true;
+            SetNeedsWriteBack(rep);
         }
     }
 
-    m_hasPendingReadBacks = false;
+    assert(m_numPendingReadBacks == 0);
+}
+
+//------------------------------------------------------------------------
+// StartStatement:
+//   Handle starting replacements within a specified statement.
+//
+// Parameters:
+//   stmt - The statement
+//
+void ReplaceVisitor::StartStatement(Statement* stmt)
+{
+    m_currentStmt       = stmt;
+    m_madeChanges       = false;
+    m_mayHaveForwardSub = false;
+
+    if (m_numPendingReadBacks == 0)
+    {
+        return;
+    }
+
+    // If we have pending readbacks then insert them as new statements for any
+    // local that the statement is using. We could leave this up to ReplaceLocal
+    // but do it here for three reasons:
+    // 1. For QMARKs we cannot actually leave it up to ReplaceLocal since the
+    // local may be conditionally executed
+    // 2. This allows forward-sub to kick in
+    // 3. Creating embedded stores in ReplaceLocal disables local copy prop for
+    //    that local (see ReplaceLocal).
+
+    for (GenTreeLclVarCommon* lcl : stmt->LocalsTreeList())
+    {
+        if (lcl->TypeIs(TYP_STRUCT))
+        {
+            continue;
+        }
+
+        AggregateInfo* agg = m_aggregates[lcl->GetLclNum()];
+        if (agg == nullptr)
+        {
+            continue;
+        }
+
+        size_t index = Promotion::BinarySearch<Replacement, &Replacement::Offset>(agg->Replacements, lcl->GetLclOffs());
+        if ((ssize_t)index < 0)
+        {
+            continue;
+        }
+
+        Replacement& rep = agg->Replacements[index];
+        if (rep.NeedsReadBack)
+        {
+            JITDUMP("Reading back replacement V%02u.[%03u..%03u) -> V%02u before [%06u]:\n", agg->LclNum, rep.Offset,
+                    rep.Offset + genTypeSize(rep.AccessType), rep.LclNum, Compiler::dspTreeID(stmt->GetRootNode()));
+
+            GenTree*   readBack = Promotion::CreateReadBack(m_compiler, agg->LclNum, rep);
+            Statement* stmt     = m_compiler->fgNewStmtFromTree(readBack);
+            DISPSTMT(stmt);
+            m_compiler->fgInsertStmtBefore(m_currentBlock, m_currentStmt, stmt);
+            ClearNeedsReadBack(rep);
+        }
+    }
 }
 
 //------------------------------------------------------------------------
@@ -1938,6 +2001,69 @@ Compiler::fgWalkResult ReplaceVisitor::PostOrderVisit(GenTree** use, GenTree* us
 }
 
 //------------------------------------------------------------------------
+// SetNeedsWriteBack:
+//   Track that a replacement is more up-to-date in the field local than the
+//   struct local.
+//
+// Remarks:
+//   This is usually the case since we generally always keep a field's value in
+//   its created primitive local.
+//
+void ReplaceVisitor::SetNeedsWriteBack(Replacement& rep)
+{
+    rep.NeedsWriteBack = true;
+    assert(!rep.NeedsReadBack);
+}
+
+//------------------------------------------------------------------------
+// ClearNeedsWriteBack:
+//   Track that a replacement is not is more up-to-date in the field local than
+//   the struct local.
+//
+void ReplaceVisitor::ClearNeedsWriteBack(Replacement& rep)
+{
+    rep.NeedsWriteBack = false;
+}
+
+//------------------------------------------------------------------------
+// SetNeedsReadBack:
+//   Track that a replacement is more up-to-date in the struct local than the
+//   field local.
+//
+// Remarks:
+//   This occurs after the struct local is assigned in a way that cannot be
+//   decomposed directly into assignments to field locals; for example because
+//   it is passed as a retbuf.
+//
+void ReplaceVisitor::SetNeedsReadBack(Replacement& rep)
+{
+    if (rep.NeedsReadBack)
+    {
+        return;
+    }
+
+    rep.NeedsReadBack = true;
+    m_numPendingReadBacks++;
+}
+
+//------------------------------------------------------------------------
+// ClearNeedsReadBack:
+//   Track that a replacement is not more up-to-date in the struct local than
+//   the field local.
+//
+void ReplaceVisitor::ClearNeedsReadBack(Replacement& rep)
+{
+    if (!rep.NeedsReadBack)
+    {
+        return;
+    }
+
+    assert(m_numPendingReadBacks > 0);
+    rep.NeedsReadBack = false;
+    m_numPendingReadBacks--;
+}
+
+//------------------------------------------------------------------------
 // InsertMidTreeReadBacksIfNecessary:
 //   If necessary, insert IR to read back all replacements before the specified use.
 //
@@ -1960,7 +2086,7 @@ Compiler::fgWalkResult ReplaceVisitor::PostOrderVisit(GenTree** use, GenTree* us
 //
 GenTree** ReplaceVisitor::InsertMidTreeReadBacksIfNecessary(GenTree** use)
 {
-    if (!m_hasPendingReadBacks || !m_compiler->ehBlockHasExnFlowDsc(m_currentBlock))
+    if ((m_numPendingReadBacks == 0) || !m_compiler->ehBlockHasExnFlowDsc(m_currentBlock))
     {
         return use;
     }
@@ -1995,7 +2121,7 @@ GenTree** ReplaceVisitor::InsertMidTreeReadBacksIfNecessary(GenTree** use)
 
             JITDUMP("  V%02.[%03u..%03u) -> V%02u\n", agg->LclNum, rep.Offset, genTypeSize(rep.AccessType), rep.LclNum);
 
-            rep.NeedsReadBack = false;
+            ClearNeedsReadBack(rep);
             GenTree* readBack = Promotion::CreateReadBack(m_compiler, agg->LclNum, rep);
             *use =
                 m_compiler->gtNewOperNode(GT_COMMA, (*use)->IsValue() ? (*use)->TypeGet() : TYP_VOID, readBack, *use);
@@ -2004,7 +2130,7 @@ GenTree** ReplaceVisitor::InsertMidTreeReadBacksIfNecessary(GenTree** use)
         }
     }
 
-    m_hasPendingReadBacks = false;
+    assert(m_numPendingReadBacks == 0);
     return use;
 }
 
@@ -2185,17 +2311,22 @@ void ReplaceVisitor::ReplaceLocal(GenTree** use, GenTree* user)
 
     if (isDef)
     {
-        rep.NeedsWriteBack = true;
-        rep.NeedsReadBack  = false;
+        ClearNeedsReadBack(rep);
+        SetNeedsWriteBack(rep);
     }
     else if (rep.NeedsReadBack)
     {
+        // This is an uncommon case -- typically all readbacks are handled in
+        // ReplaceVisitor::StartStatement. This case is still needed to handle
+        // the situation where the readback was marked previously in this tree
+        // (e.g. due to a COMMA).
+
         JITDUMP("  ..needs a read back\n");
         *use = m_compiler->gtNewOperNode(GT_COMMA, (*use)->TypeGet(),
                                          Promotion::CreateReadBack(m_compiler, lclNum, rep), *use);
-        rep.NeedsReadBack = false;
+        ClearNeedsReadBack(rep);
 
-        // TODO-CQ: Local copy prop does not take into account that the
+        // TODO: Local copy prop does not take into account that the
         // uses of LCL_VAR occur at the user, which means it may introduce
         // illegally overlapping lifetimes, such as:
         //
@@ -2291,8 +2422,8 @@ void ReplaceVisitor::WriteBackBefore(GenTree** use, unsigned lcl, unsigned offs,
             *use = comma;
             use  = &comma->gtOp2;
 
-            rep.NeedsWriteBack = false;
-            m_madeChanges      = true;
+            ClearNeedsWriteBack(rep);
+            m_madeChanges = true;
         }
 
         index++;
@@ -2311,6 +2442,11 @@ void ReplaceVisitor::WriteBackBefore(GenTree** use, unsigned lcl, unsigned offs,
 //
 void ReplaceVisitor::MarkForReadBack(GenTreeLclVarCommon* lcl, unsigned size DEBUGARG(const char* reason))
 {
+    // We currently do not handle readbacks marked within a QMARK arm, but we
+    // never create this case and we expect to expand QMARKs in an earlier pass
+    // in the (relative) near future.
+    assert(m_compiler->fgGetTopLevelQmark(m_currentStmt->GetRootNode()) == nullptr);
+
     if (m_aggregates[lcl->GetLclNum()] == nullptr)
     {
         return;
@@ -2351,12 +2487,11 @@ void ReplaceVisitor::MarkForReadBack(GenTreeLclVarCommon* lcl, unsigned size DEB
         }
         else
         {
-            rep.NeedsReadBack     = true;
-            m_hasPendingReadBacks = true;
+            SetNeedsReadBack(rep);
             JITDUMP("  V%02u (%s) marked\n", rep.LclNum, rep.Description);
         }
 
-        rep.NeedsWriteBack = false;
+        ClearNeedsWriteBack(rep);
 
         index++;
     } while ((index < replacements.size()) && (replacements[index].Offset < end));
@@ -2443,6 +2578,7 @@ PhaseStatus Promotion::Run()
         for (Statement* stmt : bb->Statements())
         {
             DISPSTMT(stmt);
+
             replacer.StartStatement(stmt);
             replacer.WalkTree(stmt->GetRootNodePointer(), nullptr);
 
index dc69eec..323fae4 100644 (file)
@@ -257,7 +257,7 @@ class ReplaceVisitor : public GenTreeVisitor<ReplaceVisitor>
     jitstd::vector<AggregateInfo*>& m_aggregates;
     PromotionLiveness*              m_liveness;
     bool                            m_madeChanges         = false;
-    bool                            m_hasPendingReadBacks = false;
+    unsigned                        m_numPendingReadBacks = 0;
     bool                            m_mayHaveForwardSub   = false;
     Statement*                      m_currentStmt         = nullptr;
     BasicBlock*                     m_currentBlock        = nullptr;
@@ -287,17 +287,16 @@ public:
 
     void StartBlock(BasicBlock* block);
     void EndBlock();
-
-    void StartStatement(Statement* stmt)
-    {
-        m_currentStmt       = stmt;
-        m_madeChanges       = false;
-        m_mayHaveForwardSub = false;
-    }
+    void StartStatement(Statement* stmt);
 
     fgWalkResult PostOrderVisit(GenTree** use, GenTree* user);
 
 private:
+    void SetNeedsWriteBack(Replacement& rep);
+    void ClearNeedsWriteBack(Replacement& rep);
+    void SetNeedsReadBack(Replacement& rep);
+    void ClearNeedsReadBack(Replacement& rep);
+
     GenTree** InsertMidTreeReadBacksIfNecessary(GenTree** use);
     void ReadBackAfterCall(GenTreeCall* call, GenTree* user);
     bool IsPromotedStructLocalDying(GenTreeLclVarCommon* structLcl);
index 381de63..4bcf8a6 100644 (file)
@@ -427,8 +427,8 @@ private:
                 statements->AddStatement(store);
             }
 
-            entry.ToReplacement->NeedsWriteBack = true;
-            entry.ToReplacement->NeedsReadBack  = false;
+            m_replacer->ClearNeedsReadBack(*entry.ToReplacement);
+            m_replacer->SetNeedsWriteBack(*entry.ToReplacement);
         }
 
         RemainderStrategy remainderStrategy = DetermineRemainderStrategy(deaths);
@@ -520,7 +520,7 @@ private:
                         // The loop below will skip these replacements as an
                         // optimization if it is going to copy the struct
                         // anyway.
-                        rep->NeedsWriteBack = false;
+                        m_replacer->ClearNeedsWriteBack(*rep);
                     }
                 }
             }
@@ -690,8 +690,8 @@ private:
 
             if (entry.ToReplacement != nullptr)
             {
-                entry.ToReplacement->NeedsWriteBack = true;
-                entry.ToReplacement->NeedsReadBack  = false;
+                m_replacer->ClearNeedsReadBack(*entry.ToReplacement);
+                m_replacer->SetNeedsWriteBack(*entry.ToReplacement);
             }
 
             if (CanSkipEntry(entry, dstDeaths, remainderStrategy DEBUGARG(/* dump */ true)))
@@ -1096,12 +1096,12 @@ void ReplaceVisitor::HandleStructStore(GenTree** use, GenTree* user)
                     // We accomplish this by an initial write back, the struct copy, followed by a later read back.
                     // TODO-CQ: This is expensive and unreflected in heuristics, but it is also very rare.
                     result.AddStatement(Promotion::CreateWriteBack(m_compiler, dstLcl->GetLclNum(), *dstFirstRep));
-                    dstFirstRep->NeedsWriteBack = false;
+                    ClearNeedsWriteBack(*dstFirstRep);
                 }
 
+                SetNeedsReadBack(*dstFirstRep);
+
                 plan.MarkNonRemainderUseOfStructLocal();
-                dstFirstRep->NeedsReadBack = true;
-                m_hasPendingReadBacks      = true;
                 dstFirstRep++;
             }
 
@@ -1116,12 +1116,12 @@ void ReplaceVisitor::HandleStructStore(GenTree** use, GenTree* user)
                     if (dstLastRep->NeedsWriteBack)
                     {
                         result.AddStatement(Promotion::CreateWriteBack(m_compiler, dstLcl->GetLclNum(), *dstLastRep));
-                        dstLastRep->NeedsWriteBack = false;
+                        ClearNeedsWriteBack(*dstLastRep);
                     }
 
+                    SetNeedsReadBack(*dstLastRep);
+
                     plan.MarkNonRemainderUseOfStructLocal();
-                    dstLastRep->NeedsReadBack = true;
-                    m_hasPendingReadBacks     = true;
                     dstEndRep--;
                 }
             }
@@ -1140,7 +1140,7 @@ void ReplaceVisitor::HandleStructStore(GenTree** use, GenTree* user)
                 if (srcFirstRep->NeedsWriteBack)
                 {
                     result.AddStatement(Promotion::CreateWriteBack(m_compiler, srcLcl->GetLclNum(), *srcFirstRep));
-                    srcFirstRep->NeedsWriteBack = false;
+                    ClearNeedsWriteBack(*srcFirstRep);
                 }
 
                 srcFirstRep++;
@@ -1157,7 +1157,7 @@ void ReplaceVisitor::HandleStructStore(GenTree** use, GenTree* user)
                     if (srcLastRep->NeedsWriteBack)
                     {
                         result.AddStatement(Promotion::CreateWriteBack(m_compiler, srcLcl->GetLclNum(), *srcLastRep));
-                        srcLastRep->NeedsWriteBack = false;
+                        ClearNeedsWriteBack(*srcLastRep);
                     }
 
                     srcEndRep--;
@@ -1308,8 +1308,8 @@ void ReplaceVisitor::InitFields(GenTreeLclVarCommon* dstStore,
                     rep->Description);
 
             // We will need to read this one back after initing the struct.
-            rep->NeedsWriteBack = false;
-            rep->NeedsReadBack  = true;
+            ClearNeedsWriteBack(*rep);
+            SetNeedsReadBack(*rep);
             plan->MarkNonRemainderUseOfStructLocal();
             continue;
         }
@@ -1395,7 +1395,7 @@ void ReplaceVisitor::CopyBetweenFields(GenTree*                    store,
 
             assert(srcLcl != nullptr);
             statements->AddStatement(Promotion::CreateReadBack(m_compiler, srcLcl->GetLclNum(), *srcRep));
-            srcRep->NeedsReadBack = false;
+            ClearNeedsReadBack(*srcRep);
             assert(!srcRep->NeedsWriteBack);
         }
 
diff --git a/src/tests/JIT/Directed/physicalpromotion/readbackbeforeqmark.cs b/src/tests/JIT/Directed/physicalpromotion/readbackbeforeqmark.cs
new file mode 100644 (file)
index 0000000..9046a2e
--- /dev/null
@@ -0,0 +1,52 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+//
+
+using System;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using Xunit;
+
+public class Runtime_87508
+{
+    [Fact]
+    public static int TestEntryPoint()
+    {
+        return new Runtime_87508().WriteBlock("1234");
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    public int WriteBlock(string source)
+    {
+        ReadOnlySpan<char> line = GetNextLine();
+        Trash();
+        // Unrolling of this creates a QMARK with LCL_FLD uses in the arms. The
+        // JIT must be careful to read the fields of the promoted 'line' back
+        // before the conditional nature of the QMARK.
+        if (line.StartsWith("{"))
+        {
+            Console.WriteLine("FAIL: succeeded");
+            return -1;
+        }
+
+        if (!Unsafe.AreSame(ref MemoryMarshal.GetReference(line), ref MemoryMarshal.GetArrayDataReference(_emptyChars)))
+        {
+            Console.WriteLine("FAIL: References were not equal");
+            return -2;
+        }
+
+        Console.WriteLine("PASS");
+        return 100;
+    }
+
+    private char[] _emptyChars = new char[0];
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    private ReadOnlySpan<char> GetNextLine()
+    {
+        return _emptyChars;
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    private nint Trash() => 0;
+}
diff --git a/src/tests/JIT/Directed/physicalpromotion/readbackbeforeqmark.csproj b/src/tests/JIT/Directed/physicalpromotion/readbackbeforeqmark.csproj
new file mode 100644 (file)
index 0000000..58658b9
--- /dev/null
@@ -0,0 +1,10 @@
+<Project Sdk="Microsoft.NET.Sdk">
+  <PropertyGroup>
+    <Optimize>True</Optimize>
+  </PropertyGroup>
+  <ItemGroup>
+    <Compile Include="$(MSBuildProjectName).cs" />
+    <CLRTestEnvironmentVariable Include="DOTNET_JitEnablePhysicalPromotion" Value="1" />
+    <CLRTestEnvironmentVariable Include="DOTNET_JitStressModeNames" Value="STRESS_NO_OLD_PROMOTION STRESS_PHYSICAL_PROMOTION_COST" />
+  </ItemGroup>
+</Project>
\ No newline at end of file