From 1b80ccba0f9bb86eabeeb0c32876eae2088eafb8 Mon Sep 17 00:00:00 2001 From: Mircea Trofin Date: Sun, 5 Feb 2023 17:16:46 -0800 Subject: [PATCH] [mlgo][regalloc] Handle training case when no regalloc happens. There's an early-exit case for regalloc when we don't even get a chance to ask for an advisor (priority or eviction), and switch the context. Then, when we want to log the reward for that function (==the one with the early exit case), we hit the error case where the function's name doesn't match the last-seen context. There are a few possible fixes, one would be to just switch context when output-ing the reward, which would be correct. This patch opts for the alternative where we check any loging happened in the first place - just to re-validate that no function would have been regaloc-ed without first log-ing its reward. Differential Revision: https://reviews.llvm.org/D143359 --- llvm/include/llvm/Analysis/Utils/TrainingLogger.h | 8 +++++- llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp | 2 +- llvm/lib/CodeGen/MLRegallocPriorityAdvisor.cpp | 2 +- llvm/test/CodeGen/MLRegalloc/empty-log.ll | 33 +++++++++++++++++++++++ 4 files changed, 42 insertions(+), 3 deletions(-) create mode 100644 llvm/test/CodeGen/MLRegalloc/empty-log.ll diff --git a/llvm/include/llvm/Analysis/Utils/TrainingLogger.h b/llvm/include/llvm/Analysis/Utils/TrainingLogger.h index ef60189..57a1637 100644 --- a/llvm/include/llvm/Analysis/Utils/TrainingLogger.h +++ b/llvm/include/llvm/Analysis/Utils/TrainingLogger.h @@ -121,8 +121,14 @@ public: const std::string ¤tContext() const { return CurrentContext; } + /// Check if there is at least an observation for `currentContext()`. bool hasObservationInProgress() const { - return ObservationIDs.find(CurrentContext) != ObservationIDs.end(); + return hasAnyObservationForContext(CurrentContext); + } + + /// Check if there is at least an observation for the context `Ctx`. + bool hasAnyObservationForContext(StringRef Ctx) const { + return ObservationIDs.find(Ctx) != ObservationIDs.end(); } template void logReward(T Value) { diff --git a/llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp b/llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp index 0064e85d..b36a86a 100644 --- a/llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp +++ b/llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp @@ -474,7 +474,7 @@ public: void logRewardIfNeeded(const MachineFunction &MF, llvm::function_ref GetReward) override { - if (!Log) + if (!Log || !Log->hasAnyObservationForContext(MF.getName())) return; // The function pass manager would run all the function passes for a // function, so we assume the last context belongs to this function. If diff --git a/llvm/lib/CodeGen/MLRegallocPriorityAdvisor.cpp b/llvm/lib/CodeGen/MLRegallocPriorityAdvisor.cpp index 7a5345e..4227815 100644 --- a/llvm/lib/CodeGen/MLRegallocPriorityAdvisor.cpp +++ b/llvm/lib/CodeGen/MLRegallocPriorityAdvisor.cpp @@ -197,7 +197,7 @@ public: void logRewardIfNeeded(const MachineFunction &MF, llvm::function_ref GetReward) override { - if (!Log) + if (!Log || !Log->hasAnyObservationForContext(MF.getName())) return; // The function pass manager would run all the function passes for a // function, so we assume the last context belongs to this function. If diff --git a/llvm/test/CodeGen/MLRegalloc/empty-log.ll b/llvm/test/CodeGen/MLRegalloc/empty-log.ll new file mode 100644 index 0000000..951a8f6 --- /dev/null +++ b/llvm/test/CodeGen/MLRegalloc/empty-log.ll @@ -0,0 +1,33 @@ +; REQUIRES: have_tflite +; REQUIRES: x86_64-linux +; +; Check that we can log more than 1 function. +; +; RUN: llc -mtriple=x86_64-linux-unknown -regalloc=greedy -regalloc-enable-advisor=development \ +; RUN: -regalloc-training-log=%t1 < %s +; RUN: FileCheck --input-file %t1 %s + +; RUN: llc -mtriple=x86_64-linux-unknown -regalloc=greedy -regalloc-enable-priority-advisor=development \ +; RUN: -regalloc-priority-training-log=%t2 < %s +; RUN: FileCheck --input-file %t2 %s + +declare void @f(); + +define void @f1(i64 %lhs, i64 %rhs, i64* %addr) { + ret void +} + +define void @f2(i64 %lhs, i64 %rhs, i64* %addr) { + %sum = add i64 %lhs, %rhs + call void @f(); + store i64 %sum, i64* %addr + ret void +} + +define void @f3(i64 %lhs, i64 %rhs, i64* %addr) { + ret void +} + +; CHECK-NOT: {"context":"f1"} +; CHECK: {"context":"f2"} +; CHECK-NOT: {"context":"f3"} -- 2.7.4