JIT: Have physical promotion insert read backs before possible implicit control flow...
author Jakob Botsch Nielsen <Jakob.botsch.nielsen@gmail.com>
Fri, 26 May 2023 19:15:33 +0000 (21:15 +0200)
committer GitHub <noreply@github.com>
Fri, 26 May 2023 19:15:33 +0000 (21:15 +0200)
Physical promotion relies on being able to read back any promoted field
that is fresher in its struct local before control flows to any
successor block. Previously, this failed to take implicit control flow
(such as an exception thrown inside a try region) into account.

Fix #86498

src/coreclr/jit/promotion.cpp
src/coreclr/jit/promotion.h
src/coreclr/jit/promotiondecomposition.cpp
src/coreclr/jit/promotionliveness.cpp
src/tests/JIT/Regression/JitBlue/Runtime_86498/Runtime_86498.cs [new file with mode: 0644]
src/tests/JIT/Regression/JitBlue/Runtime_86498/Runtime_86498.csproj [new file with mode: 0644]

index 1bf1487..5aa4bb2 100644 (file)
@@ -1150,10 +1150,68 @@ GenTree* Promotion::CreateReadBack(Compiler* compiler, unsigned structLclNum, co
     return store;
 }
 
+//------------------------------------------------------------------------
+// EndBlock:
+//   Handle reaching the end of the currently started block by preparing
+//   internal state for upcoming basic blocks, and inserting any necessary
+//   readbacks.
+//
+// Remarks:
+//   We currently expect all fields to be most up-to-date in their field locals
+//   at the beginning of every basic block. That means all replacements should
+//   have Replacement::NeedsReadBack == false and Replacement::NeedsWriteBack
+//   == true at the beginning of every block. This function ensures that is
+//   the case.
+//
+void ReplaceVisitor::EndBlock()
+{
+    for (AggregateInfo* agg : m_aggregates)
+    {
+        if (agg == nullptr) // this entry has no aggregate info, so there are no replacements to reset
+        {
+            continue;
+        }
+
+        for (size_t i = 0; i < agg->Replacements.size(); i++) // indexed loop: i is passed to the liveness query below
+        {
+            Replacement& rep = agg->Replacements[i];
+            assert(!rep.NeedsReadBack || !rep.NeedsWriteBack); // the two flags must never both be set
+            if (rep.NeedsReadBack)
+            {
+                if (m_liveness->IsReplacementLiveOut(m_currentBlock, agg->LclNum, (unsigned)i))
+                {
+                    JITDUMP("Reading back replacement V%02u.[%03u..%03u) -> V%02u near the end of " FMT_BB ":\n",
+                            agg->LclNum, rep.Offset, rep.Offset + genTypeSize(rep.AccessType), rep.LclNum,
+                            m_currentBlock->bbNum);
+
+                    GenTree*   readBack = Promotion::CreateReadBack(m_compiler, agg->LclNum, rep);
+                    Statement* stmt     = m_compiler->fgNewStmtFromTree(readBack);
+                    DISPSTMT(stmt);
+                    m_compiler->fgInsertStmtNearEnd(m_currentBlock, stmt);
+                }
+                else // replacement is not live out of m_currentBlock, so no read back is emitted
+                {
+                    JITDUMP("Skipping reading back dead replacement V%02u.[%03u..%03u) -> V%02u near the end of " FMT_BB
+                            "\n",
+                            agg->LclNum, rep.Offset, rep.Offset + genTypeSize(rep.AccessType), rep.LclNum,
+                            m_currentBlock->bbNum);
+                }
+                rep.NeedsReadBack = false; // cleared on both paths, whether or not a read back was emitted
+            }
+
+            rep.NeedsWriteBack = true; // unconditionally re-establish the expected start-of-block state
+        }
+    }
+
+    m_hasPendingReadBacks = false; // no replacement is left with NeedsReadBack set
+}
+
 Compiler::fgWalkResult ReplaceVisitor::PostOrderVisit(GenTree** use, GenTree* user)
 {
     GenTree* tree = *use;
 
+    use = InsertMidTreeReadBacksIfNecessary(use);
+
     if (tree->OperIsStore())
     {
         if (tree->OperIsLocalStore())
@@ -1192,6 +1250,80 @@ Compiler::fgWalkResult ReplaceVisitor::PostOrderVisit(GenTree** use, GenTree* us
 }
 
 //------------------------------------------------------------------------
+// InsertMidTreeReadBacksIfNecessary:
+//   If necessary, insert IR to read back all replacements before the specified use.
+//
+// Parameters:
+//   use - The use
+//
+// Returns:
+//   New use pointing to the old tree.
+//
+// Remarks:
+//   When a struct field is most up-to-date in its struct local it is marked to
+//   need a read back. We then need to decide when to insert IR to read it back
+//   to its field local.
+//
+//   We normally do this before the first use of the field we find, or before
+//   we transfer control to any successor. This method handles the case of
+//   implicit control flow related to EH: when this basic block is in a
+//   try-region (or filter block) and we find a tree that may throw, it eagerly
+//   inserts pending readbacks.
+//
+GenTree** ReplaceVisitor::InsertMidTreeReadBacksIfNecessary(GenTree** use)
+{
+    if (!m_hasPendingReadBacks || !m_compiler->ehBlockHasExnFlowDsc(m_currentBlock))
+    {
+        return use; // nothing is pending, or ehBlockHasExnFlowDsc is false for this block: leave the tree alone
+    }
+
+    if (((*use)->gtFlags & (GTF_EXCEPT | GTF_CALL)) == 0)
+    {
+        assert(!(*use)->OperMayThrow(m_compiler)); // with neither flag set, OperMayThrow is expected to agree
+        return use;
+    }
+
+    if (!(*use)->OperMayThrow(m_compiler))
+    {
+        return use; // a flag was set, but OperMayThrow reports that this node itself does not throw
+    }
+
+    JITDUMP("Reading back pending replacements before tree with possible exception side effect inside block in try "
+            "region\n");
+
+    for (AggregateInfo* agg : m_aggregates)
+    {
+        if (agg == nullptr) // this entry has no aggregate info, so there is nothing to read back
+        {
+            continue;
+        }
+
+        for (Replacement& rep : agg->Replacements)
+        {
+            // TODO-CQ: We should ensure we do not mark dead fields as
+            // requiring readback. Currently it is handled by querying liveness
+            // as part of end-of-block readback insertion, but for these
+            // mid-tree readbacks we cannot query liveness information for
+            // arbitrary locals.
+            if (!rep.NeedsReadBack)
+            {
+                continue; // only replacements currently marked for read back are handled
+            }
+
+            rep.NeedsReadBack = false; // the read back is created right below, so it is no longer pending
+            GenTree* readBack = Promotion::CreateReadBack(m_compiler, agg->LclNum, rep);
+            *use =
+                m_compiler->gtNewOperNode(GT_COMMA, (*use)->IsValue() ? (*use)->TypeGet() : TYP_VOID, readBack, *use);
+            use           = &(*use)->AsOp()->gtOp2; // re-point use at the old tree, now op2 of COMMA(readBack, tree)
+            m_madeChanges = true;
+        }
+    }
+
+    m_hasPendingReadBacks = false; // every replacement with NeedsReadBack set was handled above
+    return use;
+}
+
+//------------------------------------------------------------------------
 // LoadStoreAroundCall:
 //   Handle a call that may involve struct local arguments and that may
 //   pass a struct local with replacements as the retbuf.
@@ -1237,7 +1369,10 @@ void ReplaceVisitor::LoadStoreAroundCall(GenTreeCall* call, GenTree* user)
         GenTreeLclVarCommon* retBufLcl = retBufArg->GetNode()->AsLclVarCommon();
         unsigned             size      = m_compiler->typGetObjLayout(call->gtRetClsHnd)->GetSize();
 
-        MarkForReadBack(retBufLcl->GetLclNum(), retBufLcl->GetLclOffs(), size);
+        if (MarkForReadBack(retBufLcl->GetLclNum(), retBufLcl->GetLclOffs(), size))
+        {
+            JITDUMP("Retbuf has replacements that were marked for read back\n");
+        }
     }
 }
 
@@ -1269,10 +1404,9 @@ bool ReplaceVisitor::IsPromotedStructLocalDying(GenTreeLclVarCommon* lcl)
     }
 
     AggregateInfo* agg = m_aggregates[lcl->GetLclNum()];
-
-    for (size_t i = 0; i < agg->Replacements.size(); i++)
+    for (Replacement& rep : agg->Replacements)
     {
-        if (agg->Replacements[i].NeedsReadBack)
+        if (rep.NeedsReadBack)
         {
             return false;
         }
@@ -1487,11 +1621,11 @@ void ReplaceVisitor::WriteBackBefore(GenTree** use, unsigned lcl, unsigned offs,
 //   offs         - The starting offset of the range in the struct local that needs to be read back from.
 //   size         - The size of the range
 //
-void ReplaceVisitor::MarkForReadBack(unsigned lcl, unsigned offs, unsigned size)
+bool ReplaceVisitor::MarkForReadBack(unsigned lcl, unsigned offs, unsigned size)
 {
     if (m_aggregates[lcl] == nullptr)
     {
-        return;
+        return false;
     }
 
     jitstd::vector<Replacement>& replacements = m_aggregates[lcl]->Replacements;
@@ -1506,17 +1640,20 @@ void ReplaceVisitor::MarkForReadBack(unsigned lcl, unsigned offs, unsigned size)
         }
     }
 
-    bool     result = false;
-    unsigned end    = offs + size;
+    bool     any = false;
+    unsigned end = offs + size;
     while ((index < replacements.size()) && (replacements[index].Offset < end))
     {
-        result           = true;
+        any              = true;
         Replacement& rep = replacements[index];
         assert(rep.Overlaps(offs, size));
-        rep.NeedsReadBack  = true;
-        rep.NeedsWriteBack = false;
+        rep.NeedsReadBack     = true;
+        rep.NeedsWriteBack    = false;
+        m_hasPendingReadBacks = true;
         index++;
     }
+
+    return any;
 }
 
 //------------------------------------------------------------------------
@@ -1631,10 +1768,12 @@ PhaseStatus Promotion::Run()
     ReplaceVisitor replacer(this, aggregates, &liveness);
     for (BasicBlock* bb : m_compiler->Blocks())
     {
+        replacer.StartBlock(bb);
+
         for (Statement* stmt : bb->Statements())
         {
             DISPSTMT(stmt);
-            replacer.Reset();
+            replacer.StartStatement();
             replacer.WalkTree(stmt->GetRootNodePointer(), nullptr);
 
             if (replacer.MadeChanges())
@@ -1646,42 +1785,7 @@ PhaseStatus Promotion::Run()
             }
         }
 
-        for (unsigned i = 0; i < numLocals; i++)
-        {
-            if (aggregates[i] == nullptr)
-            {
-                continue;
-            }
-
-            for (size_t j = 0; j < aggregates[i]->Replacements.size(); j++)
-            {
-                Replacement& rep = aggregates[i]->Replacements[j];
-                assert(!rep.NeedsReadBack || !rep.NeedsWriteBack);
-                if (rep.NeedsReadBack)
-                {
-                    if (liveness.IsReplacementLiveOut(bb, i, (unsigned)j))
-                    {
-                        JITDUMP("Reading back replacement V%02u.[%03u..%03u) -> V%02u near the end of " FMT_BB ":\n", i,
-                                rep.Offset, rep.Offset + genTypeSize(rep.AccessType), rep.LclNum, bb->bbNum);
-
-                        GenTree*   readBack = CreateReadBack(m_compiler, i, rep);
-                        Statement* stmt     = m_compiler->fgNewStmtFromTree(readBack);
-                        DISPSTMT(stmt);
-                        m_compiler->fgInsertStmtNearEnd(bb, stmt);
-                    }
-                    else
-                    {
-                        JITDUMP(
-                            "Skipping reading back dead replacement V%02u.[%03u..%03u) -> V%02u near the end of " FMT_BB
-                            "\n",
-                            i, rep.Offset, rep.Offset + genTypeSize(rep.AccessType), rep.LclNum, bb->bbNum);
-                    }
-                    rep.NeedsReadBack = false;
-                }
-
-                rep.NeedsWriteBack = true;
-            }
-        }
+        replacer.EndBlock();
     }
 
     // Insert initial IR to read arguments/OSR locals into replacement locals,
index 895058e..7f64613 100644 (file)
@@ -244,10 +244,11 @@ class DecompositionPlan;
 
 class ReplaceVisitor : public GenTreeVisitor<ReplaceVisitor>
 {
-    Promotion*                      m_prom;
     jitstd::vector<AggregateInfo*>& m_aggregates;
     PromotionLiveness*              m_liveness;
-    bool                            m_madeChanges = false;
+    bool                            m_madeChanges         = false;
+    bool                            m_hasPendingReadBacks = false;
+    BasicBlock*                     m_currentBlock        = nullptr;
 
 public:
     enum
@@ -257,7 +258,7 @@ public:
     };
 
     ReplaceVisitor(Promotion* prom, jitstd::vector<AggregateInfo*>& aggregates, PromotionLiveness* liveness)
-        : GenTreeVisitor(prom->m_compiler), m_prom(prom), m_aggregates(aggregates), m_liveness(liveness)
+        : GenTreeVisitor(prom->m_compiler), m_aggregates(aggregates), m_liveness(liveness)
     {
     }
 
@@ -266,7 +267,14 @@ public:
         return m_madeChanges;
     }
 
-    void Reset()
+    void StartBlock(BasicBlock* block)
+    {
+        m_currentBlock = block;
+    }
+
+    void EndBlock();
+
+    void StartStatement()
     {
         m_madeChanges = false;
     }
@@ -274,12 +282,13 @@ public:
     fgWalkResult PostOrderVisit(GenTree** use, GenTree* user);
 
 private:
+    GenTree** InsertMidTreeReadBacksIfNecessary(GenTree** use);
     void LoadStoreAroundCall(GenTreeCall* call, GenTree* user);
     bool IsPromotedStructLocalDying(GenTreeLclVarCommon* structLcl);
     void ReplaceLocal(GenTree** use, GenTree* user);
     void StoreBeforeReturn(GenTreeUnOp* ret);
     void WriteBackBefore(GenTree** use, unsigned lcl, unsigned offs, unsigned size);
-    void MarkForReadBack(unsigned lcl, unsigned offs, unsigned size);
+    bool MarkForReadBack(unsigned lcl, unsigned offs, unsigned size);
 
     void HandleStore(GenTree** use, GenTree* user);
     bool OverlappingReplacements(GenTreeLclVarCommon* lcl,
index fcfd81a..d0df8b9 100644 (file)
@@ -1033,6 +1033,7 @@ void ReplaceVisitor::HandleStore(GenTree** use, GenTree* user)
 
                 plan.MarkNonRemainderUseOfStructLocal();
                 dstFirstRep->NeedsReadBack = true;
+                m_hasPendingReadBacks      = true;
                 dstFirstRep++;
             }
 
@@ -1052,6 +1053,7 @@ void ReplaceVisitor::HandleStore(GenTree** use, GenTree* user)
 
                     plan.MarkNonRemainderUseOfStructLocal();
                     dstLastRep->NeedsReadBack = true;
+                    m_hasPendingReadBacks     = true;
                     dstEndRep--;
                 }
             }
@@ -1122,7 +1124,10 @@ void ReplaceVisitor::HandleStore(GenTree** use, GenTree* user)
         {
             GenTreeLclVarCommon* lclStore = store->AsLclVarCommon();
             unsigned             size     = lclStore->GetLayout(m_compiler)->GetSize();
-            MarkForReadBack(lclStore->GetLclNum(), lclStore->GetLclOffs(), size);
+            if (MarkForReadBack(lclStore->GetLclNum(), lclStore->GetLclOffs(), size))
+            {
+                JITDUMP("Marked store destination replacements to be read back (could not decompose this store)\n");
+            }
         }
     }
 }
index a345637..71ac5a9 100644 (file)
@@ -74,15 +74,14 @@ void PromotionLiveness::Run()
 {
     m_structLclToTrackedIndex = new (m_compiler, CMK_Promotion) unsigned[m_aggregates.size()]{};
     unsigned trackedIndex     = 0;
-    for (size_t lclNum = 0; lclNum < m_aggregates.size(); lclNum++)
+    for (AggregateInfo* agg : m_aggregates)
     {
-        AggregateInfo* agg = m_aggregates[lclNum];
         if (agg == nullptr)
         {
             continue;
         }
 
-        m_structLclToTrackedIndex[lclNum] = trackedIndex;
+        m_structLclToTrackedIndex[agg->LclNum] = trackedIndex;
         // TODO: We need a scalability limit on these, we cannot always track
         // the remainder and all fields.
         // Remainder.
@@ -93,7 +92,7 @@ void PromotionLiveness::Run()
 #ifdef DEBUG
         // Mark the struct local (remainder) and fields as tracked for DISPTREE to properly
         // show last use information.
-        m_compiler->lvaGetDesc((unsigned)lclNum)->lvTrackedWithoutIndex = true;
+        m_compiler->lvaGetDesc(agg->LclNum)->lvTrackedWithoutIndex = true;
         for (size_t i = 0; i < agg->Replacements.size(); i++)
         {
             m_compiler->lvaGetDesc(agg->Replacements[i].LclNum)->lvTrackedWithoutIndex = true;
@@ -830,9 +829,8 @@ void PromotionLiveness::DumpVarSet(BitVec set, BitVec allVars)
     printf("{");
 
     const char* sep = "";
-    for (size_t i = 0; i < m_aggregates.size(); i++)
+    for (AggregateInfo* agg : m_aggregates)
     {
-        AggregateInfo* agg = m_aggregates[i];
         if (agg == nullptr)
         {
             continue;
@@ -840,18 +838,18 @@ void PromotionLiveness::DumpVarSet(BitVec set, BitVec allVars)
 
         for (size_t j = 0; j <= agg->Replacements.size(); j++)
         {
-            unsigned index = (unsigned)(m_structLclToTrackedIndex[i] + j);
+            unsigned index = (unsigned)(m_structLclToTrackedIndex[agg->LclNum] + j);
 
             if (BitVecOps::IsMember(m_bvTraits, set, index))
             {
                 if (j == 0)
                 {
-                    printf("%sV%02u(remainder)", sep, (unsigned)i);
+                    printf("%sV%02u(remainder)", sep, agg->LclNum);
                 }
                 else
                 {
                     const Replacement& rep = agg->Replacements[j - 1];
-                    printf("%sV%02u.[%03u..%03u)", sep, (unsigned)i, rep.Offset,
+                    printf("%sV%02u.[%03u..%03u)", sep, agg->LclNum, rep.Offset,
                            rep.Offset + genTypeSize(rep.AccessType));
                 }
                 sep = " ";
diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_86498/Runtime_86498.cs b/src/tests/JIT/Regression/JitBlue/Runtime_86498/Runtime_86498.cs
new file mode 100644 (file)
index 0000000..5542a74
--- /dev/null
@@ -0,0 +1,49 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Runtime.CompilerServices;
+using Xunit;
+
+public class Runtime_86498
+{
+    [Fact]
+    public static int Test()
+    {
+        Foo f = new();
+        try
+        {
+            f.X = 15;
+            f.Y = 20;
+            f.X += f.Y;
+            f.Y *= f.X;
+
+            // f will be physically promoted and will require a read back after this call.
+            // However, the read back must happen before the implicit control flow (the exception thrown below).
+            f = Call(f);
+            ThrowException();
+            return -1; // not expected to be reached: ThrowException always throws
+        }
+        catch (Exception ex)
+        {
+            return f.X + f.Y; // 75 + 25 = 100 when f holds the struct returned by Call
+        }
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    private static Foo Call(Foo f)
+    {
+        return new Foo { X = 75, Y = 25 }; // ignores the argument and always returns the same field values
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    private static void ThrowException()
+    {
+        throw new Exception(); // unconditionally throws, transferring control to the catch in Test
+    }
+
+    private struct Foo
+    {
+        public short X, Y;
+    }
+}
diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_86498/Runtime_86498.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_86498/Runtime_86498.csproj
new file mode 100644 (file)
index 0000000..85f04c1
--- /dev/null
@@ -0,0 +1,9 @@
+<Project Sdk="Microsoft.NET.Sdk">
+  <PropertyGroup>
+    <Optimize>True</Optimize>
+  </PropertyGroup>
+  <ItemGroup>
+    <Compile Include="$(MSBuildProjectName).cs" />
+    <CLRTestEnvironmentVariable Include="DOTNET_JitStressModeNames" Value="STRESS_PHYSICAL_PROMOTION STRESS_PHYSICAL_PROMOTION_COST STRESS_NO_OLD_PROMOTION" />
+  </ItemGroup>
+</Project>
\ No newline at end of file