--- /dev/null
+//===- RegAllocScore.cpp - evaluate regalloc policy quality ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// Calculate a measure of the register allocation policy quality. This is used
+/// to construct a reward for the training of the ML-driven allocation policy.
+/// Currently, the score is the sum of the machine basic block frequency-weighed
+/// number of loads, stores, copies, and remat instructions, each factored with
+/// a relative weight.
+//===----------------------------------------------------------------------===//
+
+#include "RegAllocScore.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Target/TargetMachine.h"
+#include <cassert>
+#include <cstdint>
+#include <numeric>
+#include <vector>
+
+using namespace llvm;
+cl::opt<double> CopyWeight("regalloc-copy-weight", cl::init(0.2), cl::Hidden);
+cl::opt<double> LoadWeight("regalloc-load-weight", cl::init(4.0), cl::Hidden);
+cl::opt<double> StoreWeight("regalloc-store-weight", cl::init(1.0), cl::Hidden);
+cl::opt<double> CheapRematWeight("regalloc-cheap-remat-weight", cl::init(0.2),
+ cl::Hidden);
+cl::opt<double> ExpensiveRematWeight("regalloc-expensive-remat-weight",
+ cl::init(1.0), cl::Hidden);
+#define DEBUG_TYPE "regalloc-score"
+
+RegAllocScore &RegAllocScore::operator+=(const RegAllocScore &Other) {
+ CopyCounts += Other.copyCounts();
+ LoadCounts += Other.loadCounts();
+ StoreCounts += Other.storeCounts();
+ LoadStoreCounts += Other.loadStoreCounts();
+ CheapRematCounts += Other.cheapRematCounts();
+ ExpensiveRematCounts += Other.expensiveRematCounts();
+ return *this;
+}
+
+bool RegAllocScore::operator==(const RegAllocScore &Other) const {
+ return copyCounts() == Other.copyCounts() &&
+ loadCounts() == Other.loadCounts() &&
+ storeCounts() == Other.storeCounts() &&
+ loadStoreCounts() == Other.loadStoreCounts() &&
+ cheapRematCounts() == Other.cheapRematCounts() &&
+ expensiveRematCounts() == Other.expensiveRematCounts();
+}
+
+bool RegAllocScore::operator!=(const RegAllocScore &Other) const {
+ return !(*this == Other);
+}
+
+double RegAllocScore::getScore() const {
+ double Ret = 0.0;
+ Ret += CopyWeight * copyCounts();
+ Ret += LoadWeight * loadCounts();
+ Ret += StoreWeight * storeCounts();
+ Ret += (LoadWeight + StoreWeight) * loadStoreCounts();
+ Ret += CheapRematWeight * cheapRematCounts();
+ Ret += ExpensiveRematWeight * expensiveRematCounts();
+
+ return Ret;
+}
+
+RegAllocScore
+llvm::calculateRegAllocScore(const MachineFunction &MF,
+ const MachineBlockFrequencyInfo &MBFI,
+ AAResults &AAResults) {
+ return calculateRegAllocScore(
+ MF,
+ [&](const MachineBasicBlock &MBB) {
+ return MBFI.getBlockFreqRelativeToEntryBlock(&MBB);
+ },
+ [&](const MachineInstr &MI) {
+ return MF.getSubtarget().getInstrInfo()->isTriviallyReMaterializable(
+ MI, &AAResults);
+ });
+}
+
+RegAllocScore llvm::calculateRegAllocScore(
+ const MachineFunction &MF,
+ llvm::function_ref<double(const MachineBasicBlock &)> GetBBFreq,
+ llvm::function_ref<bool(const MachineInstr &)>
+ IsTriviallyRematerializable) {
+ RegAllocScore Total;
+
+ for (const MachineBasicBlock &MBB : MF) {
+ double BlockFreqRelativeToEntrypoint = GetBBFreq(MBB);
+ RegAllocScore MBBScore;
+
+ for (const MachineInstr &MI : MBB) {
+ if (MI.isDebugInstr() || MI.isKill() || MI.isInlineAsm()) {
+ continue;
+ }
+ if (MI.isCopy()) {
+ MBBScore.onCopy(BlockFreqRelativeToEntrypoint);
+ } else if (IsTriviallyRematerializable(MI)) {
+ if (MI.getDesc().isAsCheapAsAMove()) {
+ MBBScore.onCheapRemat(BlockFreqRelativeToEntrypoint);
+ } else {
+ MBBScore.onExpensiveRemat(BlockFreqRelativeToEntrypoint);
+ }
+ } else if (MI.mayLoad() && MI.mayStore()) {
+ MBBScore.onLoadStore(BlockFreqRelativeToEntrypoint);
+ } else if (MI.mayLoad()) {
+ MBBScore.onLoad(BlockFreqRelativeToEntrypoint);
+ } else if (MI.mayStore()) {
+ MBBScore.onStore(BlockFreqRelativeToEntrypoint);
+ }
+ }
+ Total += MBBScore;
+ }
+ return Total;
+}
--- /dev/null
+//==- RegAllocScore.h - evaluate regalloc policy quality ----------*-C++-*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// Calculate a measure of the register allocation policy quality. This is used
+/// to construct a reward for the training of the ML-driven allocation policy.
+/// Currently, the score is the sum of the machine basic block frequency-weighed
+/// number of loads, stores, copies, and remat instructions, each factored with
+/// a relative weight.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_REGALLOCSCORE_H_
+#define LLVM_CODEGEN_REGALLOCSCORE_H_
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Analysis/ProfileSummaryInfo.h"
+#include "llvm/Analysis/Utils/TFUtils.h"
+#include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/SelectionDAGNodes.h"
+#include "llvm/IR/Module.h"
+#include <cassert>
+#include <cstdint>
+#include <limits>
+
+namespace llvm {
+
+/// Regalloc score.
+class RegAllocScore final {
+ double CopyCounts = 0.0;
+ double LoadCounts = 0.0;
+ double StoreCounts = 0.0;
+ double CheapRematCounts = 0.0;
+ double LoadStoreCounts = 0.0;
+ double ExpensiveRematCounts = 0.0;
+
+public:
+ RegAllocScore() = default;
+ RegAllocScore(const RegAllocScore &) = default;
+
+ double copyCounts() const { return CopyCounts; }
+ double loadCounts() const { return LoadCounts; }
+ double storeCounts() const { return StoreCounts; }
+ double loadStoreCounts() const { return LoadStoreCounts; }
+ double expensiveRematCounts() const { return ExpensiveRematCounts; }
+ double cheapRematCounts() const { return CheapRematCounts; }
+
+ void onCopy(double Freq) { CopyCounts += Freq; }
+ void onLoad(double Freq) { LoadCounts += Freq; }
+ void onStore(double Freq) { StoreCounts += Freq; }
+ void onLoadStore(double Freq) { LoadStoreCounts += Freq; }
+ void onExpensiveRemat(double Freq) { ExpensiveRematCounts += Freq; }
+ void onCheapRemat(double Freq) { CheapRematCounts += Freq; }
+
+ RegAllocScore &operator+=(const RegAllocScore &Other);
+ bool operator==(const RegAllocScore &Other) const;
+ bool operator!=(const RegAllocScore &Other) const;
+ double getScore() const;
+};
+
+/// Calculate a score. When comparing 2 scores for the same function but
+/// different policies, the better policy would have a smaller score.
+/// The implementation is the overload below (which is also easily unittestable)
+RegAllocScore calculateRegAllocScore(const MachineFunction &MF,
+ const MachineBlockFrequencyInfo &MBFI,
+ AAResults &AAResults);
+
+/// Implementation of the above, which is also more easily unittestable.
+RegAllocScore calculateRegAllocScore(
+ const MachineFunction &MF,
+ llvm::function_ref<double(const MachineBasicBlock &)> GetBBFreq,
+ llvm::function_ref<bool(const MachineInstr &)> IsTriviallyRematerializable);
+} // end namespace llvm
+
+#endif // LLVM_CODEGEN_REGALLOCSCORE_H_
--- /dev/null
+//===- MachineInstrTest.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../lib/CodeGen/RegAllocScore.h"
+#include "llvm/ADT/Triple.h"
+#include "llvm/CodeGen/MachineBasicBlock.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineMemOperand.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/TargetFrameLowering.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/CodeGen/TargetSubtargetInfo.h"
+#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/ModuleSlotTracker.h"
+#include "llvm/MC/MCAsmInfo.h"
+#include "llvm/MC/MCSymbol.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Target/TargetOptions.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+extern cl::opt<double> CopyWeight;
+extern cl::opt<double> LoadWeight;
+extern cl::opt<double> StoreWeight;
+extern cl::opt<double> CheapRematWeight;
+extern cl::opt<double> ExpensiveRematWeight;
+
+namespace {
+// Include helper functions to ease the manipulation of MachineFunctions.
+#include "MFCommon.inc"
+
+// MachineFunction::CreateMachineInstr doesn't copy the MCInstrDesc, it
+// takes its address. So we want a bunch of pre-allocated mock MCInstrDescs.
+#define MOCK_INSTR(MACRO) \
+ MACRO(Copy, TargetOpcode::COPY, 0) \
+ MACRO(Load, 0, 1ULL << MCID::MayLoad) \
+ MACRO(Store, 0, 1ULL << MCID::MayStore) \
+ MACRO(LoadStore, 0, (1ULL << MCID::MayLoad) | (1ULL << MCID::MayStore)) \
+ MACRO(CheapRemat, 0, 1ULL << MCID::CheapAsAMove) \
+ MACRO(ExpensiveRemat, 0, 0) \
+ MACRO(Dbg, TargetOpcode::DBG_LABEL, \
+ (1ULL << MCID::MayLoad) | (1ULL << MCID::MayStore)) \
+ MACRO(InlAsm, TargetOpcode::INLINEASM, \
+ (1ULL << MCID::MayLoad) | (1ULL << MCID::MayStore)) \
+ MACRO(Kill, TargetOpcode::KILL, \
+ (1ULL << MCID::MayLoad) | (1ULL << MCID::MayStore))
+
+enum MockInstrId {
+#define MOCK_INSTR_ID(ID, IGNORE, IGNORE2) ID,
+ MOCK_INSTR(MOCK_INSTR_ID)
+#undef MOCK_INSTR_ID
+ TotalMockInstrs
+};
+
+const std::array<MCInstrDesc, MockInstrId::TotalMockInstrs> MockInstrDescs{{
+#define MOCK_SPEC(IGNORE, OPCODE, FLAGS) \
+ {OPCODE, 0, 0, 0, 0, FLAGS, 0, nullptr, nullptr, nullptr},
+ MOCK_INSTR(MOCK_SPEC)
+#undef MOCK_SPEC
+}};
+
+MachineInstr *createMockCopy(MachineFunction &MF) {
+ return MF.CreateMachineInstr(MockInstrDescs[MockInstrId::Copy], DebugLoc());
+}
+
+MachineInstr *createMockLoad(MachineFunction &MF) {
+ return MF.CreateMachineInstr(MockInstrDescs[MockInstrId::Load], DebugLoc());
+}
+
+MachineInstr *createMockStore(MachineFunction &MF) {
+ return MF.CreateMachineInstr(MockInstrDescs[MockInstrId::Store], DebugLoc());
+}
+
+MachineInstr *createMockLoadStore(MachineFunction &MF) {
+ return MF.CreateMachineInstr(MockInstrDescs[MockInstrId::LoadStore],
+ DebugLoc());
+}
+
+MachineInstr *createMockCheapRemat(MachineFunction &MF) {
+ return MF.CreateMachineInstr(MockInstrDescs[MockInstrId::CheapRemat],
+ DebugLoc());
+}
+
+MachineInstr *createMockExpensiveRemat(MachineFunction &MF) {
+ return MF.CreateMachineInstr(MockInstrDescs[MockInstrId::ExpensiveRemat],
+ DebugLoc());
+}
+
+MachineInstr *createMockDebug(MachineFunction &MF) {
+ return MF.CreateMachineInstr(MockInstrDescs[MockInstrId::Dbg], DebugLoc());
+}
+
+MachineInstr *createMockKill(MachineFunction &MF) {
+ return MF.CreateMachineInstr(MockInstrDescs[MockInstrId::Kill], DebugLoc());
+}
+
+MachineInstr *createMockInlineAsm(MachineFunction &MF) {
+ return MF.CreateMachineInstr(MockInstrDescs[MockInstrId::InlAsm], DebugLoc());
+}
+
+TEST(RegAllocScoreTest, SkipDebugKillInlineAsm) {
+ LLVMContext Ctx;
+ Module Mod("Module", Ctx);
+ auto MF = createMachineFunction(Ctx, Mod);
+
+ auto *MBB = MF->CreateMachineBasicBlock();
+ MF->insert(MF->end(), MBB);
+ auto MBBFreqMock = [&](const MachineBasicBlock &_MBB) -> double {
+ assert(&_MBB == MBB);
+ return 0.5;
+ };
+ auto Next = MBB->end();
+ Next = MBB->insertAfter(Next, createMockInlineAsm(*MF));
+ Next = MBB->insertAfter(Next, createMockDebug(*MF));
+ Next = MBB->insertAfter(Next, createMockKill(*MF));
+ const auto Score = llvm::calculateRegAllocScore(
+ *MF, MBBFreqMock, [](const MachineInstr &) { return false; });
+ ASSERT_EQ(MF->size(), 1U);
+ ASSERT_EQ(Score, RegAllocScore());
+}
+
+TEST(RegAllocScoreTest, Counts) {
+ LLVMContext Ctx;
+ Module Mod("Module", Ctx);
+ auto MF = createMachineFunction(Ctx, Mod);
+
+ auto *MBB1 = MF->CreateMachineBasicBlock();
+ auto *MBB2 = MF->CreateMachineBasicBlock();
+ MF->insert(MF->end(), MBB1);
+ MF->insert(MF->end(), MBB2);
+ const double Freq1 = 0.5;
+ const double Freq2 = 10.0;
+ auto MBBFreqMock = [&](const MachineBasicBlock &MBB) -> double {
+ if (&MBB == MBB1)
+ return Freq1;
+ if (&MBB == MBB2)
+ return Freq2;
+ assert(false && "We only created 2 basic blocks");
+ };
+ auto Next = MBB1->end();
+ Next = MBB1->insertAfter(Next, createMockCopy(*MF));
+ Next = MBB1->insertAfter(Next, createMockLoad(*MF));
+ Next = MBB1->insertAfter(Next, createMockLoad(*MF));
+ Next = MBB1->insertAfter(Next, createMockStore(*MF));
+ auto *CheapRemat = createMockCheapRemat(*MF);
+ MBB1->insertAfter(Next, CheapRemat);
+ Next = MBB2->end();
+ Next = MBB2->insertAfter(Next, createMockLoad(*MF));
+ Next = MBB2->insertAfter(Next, createMockStore(*MF));
+ Next = MBB2->insertAfter(Next, createMockLoadStore(*MF));
+ auto *ExpensiveRemat = createMockExpensiveRemat(*MF);
+ MBB2->insertAfter(Next, ExpensiveRemat);
+ auto IsRemat = [&](const MachineInstr &MI) {
+ return &MI == CheapRemat || &MI == ExpensiveRemat;
+ };
+ ASSERT_EQ(MF->size(), 2U);
+ const auto TotalScore =
+ llvm::calculateRegAllocScore(*MF, MBBFreqMock, IsRemat);
+ ASSERT_EQ(Freq1, TotalScore.copyCounts());
+ ASSERT_EQ(2.0 * Freq1 + Freq2, TotalScore.loadCounts());
+ ASSERT_EQ(Freq1 + Freq2, TotalScore.storeCounts());
+ ASSERT_EQ(Freq2, TotalScore.loadStoreCounts());
+ ASSERT_EQ(Freq1, TotalScore.cheapRematCounts());
+ ASSERT_EQ(Freq2, TotalScore.expensiveRematCounts());
+ ASSERT_EQ(TotalScore.getScore(),
+ TotalScore.copyCounts() * CopyWeight +
+ TotalScore.loadCounts() * LoadWeight +
+ TotalScore.storeCounts() * StoreWeight +
+ TotalScore.loadStoreCounts() * (LoadWeight + StoreWeight) +
+ TotalScore.cheapRematCounts() * CheapRematWeight +
+ TotalScore.expensiveRematCounts() * ExpensiveRematWeight
+
+ );
+}
+} // end namespace