[RISCV][MC] Refine MCInstrAnalysis based on registers used
authorJob Noorman <jnoorman@igalia.com>
Wed, 17 May 2023 09:43:38 +0000 (11:43 +0200)
committerJob Noorman <jnoorman@igalia.com>
Wed, 17 May 2023 09:43:39 +0000 (11:43 +0200)
MCInstrAnalysis provides a number of methods to query properties of
instructions (e.g., isTerminator(), isCall(),...). The default
implementation of these methods forwards the query to MCInstrDesc which
returns information based on various RISCVInstrInfo*.td files.

Since the info in MCInstrDesc is based on opcodes only, it is often
quite inaccurate. For example, JAL/JALR are never recognized as
terminators or branches while they certainly can be. However,
MCInstrAnalysis has access to the full MCInst so can improve accuracy by
inspecting registers used by the instruction.

This patch refines the following MCInstrAnalysis methods:
- isTerminator: JAL/JALR with RD=X0;
- isCall: JAL/JALR with RD!=X0
- isReturn: JALR/C_JR with RD=X0, RS1 in {X1, X5}
- isBranch: JAL/JALR/C_JR with RD=X0, RS1 not in {X1, X5};
- isUnconditionalBranch: JAL/JALR/C_JR with RD=X0, RS1 not in {X1, X5};
- isIndirectBranch: JALR/C_JR with RD=X0, RS1 not in {X1, X5};

Note that the reason for this patch is to simplify the RISCV target in
BOLT. While it's possible to implement everything there, it seems more
logical to implement it directly in the RISCV backend as other tools
might also be able to benefit from it.

Reviewed By: craig.topper, MaskRay

Differential Revision: https://reviews.llvm.org/D146438

llvm/lib/Target/RISCV/MCTargetDesc/RISCVMCTargetDesc.cpp
llvm/unittests/Target/RISCV/CMakeLists.txt [new file with mode: 0644]
llvm/unittests/Target/RISCV/MCInstrAnalysisTest.cpp [new file with mode: 0644]

index f9f7ac2..75af5c2 100644 (file)
@@ -142,6 +142,96 @@ public:
 
     return false;
   }
+
+  bool isTerminator(const MCInst &Inst) const override {
+    if (MCInstrAnalysis::isTerminator(Inst))
+      return true;
+
+    switch (Inst.getOpcode()) {
+    default:
+      return false;
+    case RISCV::JAL:
+    case RISCV::JALR:
+      return Inst.getOperand(0).getReg() == RISCV::X0;
+    }
+  }
+
+  bool isCall(const MCInst &Inst) const override {
+    if (MCInstrAnalysis::isCall(Inst))
+      return true;
+
+    switch (Inst.getOpcode()) {
+    default:
+      return false;
+    case RISCV::JAL:
+    case RISCV::JALR:
+      return Inst.getOperand(0).getReg() != RISCV::X0;
+    }
+  }
+
+  bool isReturn(const MCInst &Inst) const override {
+    if (MCInstrAnalysis::isReturn(Inst))
+      return true;
+
+    switch (Inst.getOpcode()) {
+    default:
+      return false;
+    case RISCV::JALR:
+      return Inst.getOperand(0).getReg() == RISCV::X0 &&
+             maybeReturnAddress(Inst.getOperand(1).getReg());
+    case RISCV::C_JR:
+      return maybeReturnAddress(Inst.getOperand(0).getReg());
+    }
+  }
+
+  bool isBranch(const MCInst &Inst) const override {
+    if (MCInstrAnalysis::isBranch(Inst))
+      return true;
+
+    return isBranchImpl(Inst);
+  }
+
+  bool isUnconditionalBranch(const MCInst &Inst) const override {
+    if (MCInstrAnalysis::isUnconditionalBranch(Inst))
+      return true;
+
+    return isBranchImpl(Inst);
+  }
+
+  bool isIndirectBranch(const MCInst &Inst) const override {
+    if (MCInstrAnalysis::isIndirectBranch(Inst))
+      return true;
+
+    switch (Inst.getOpcode()) {
+    default:
+      return false;
+    case RISCV::JALR:
+      return Inst.getOperand(0).getReg() == RISCV::X0 &&
+             !maybeReturnAddress(Inst.getOperand(1).getReg());
+    case RISCV::C_JR:
+      return !maybeReturnAddress(Inst.getOperand(0).getReg());
+    }
+  }
+
+private:
+  static bool maybeReturnAddress(unsigned Reg) {
+    // X1 is used for normal returns, X5 for returns from outlined functions.
+    return Reg == RISCV::X1 || Reg == RISCV::X5;
+  }
+
+  static bool isBranchImpl(const MCInst &Inst) {
+    switch (Inst.getOpcode()) {
+    default:
+      return false;
+    case RISCV::JAL:
+      return Inst.getOperand(0).getReg() == RISCV::X0;
+    case RISCV::JALR:
+      return Inst.getOperand(0).getReg() == RISCV::X0 &&
+             !maybeReturnAddress(Inst.getOperand(1).getReg());
+    case RISCV::C_JR:
+      return !maybeReturnAddress(Inst.getOperand(0).getReg());
+    }
+  }
 };
 
 } // end anonymous namespace
diff --git a/llvm/unittests/Target/RISCV/CMakeLists.txt b/llvm/unittests/Target/RISCV/CMakeLists.txt
new file mode 100644 (file)
index 0000000..2c757b8
--- /dev/null
@@ -0,0 +1,18 @@
+include_directories(
+  ${LLVM_MAIN_SRC_DIR}/lib/Target/RISCV
+  ${LLVM_BINARY_DIR}/lib/Target/RISCV
+  )
+
+set(LLVM_LINK_COMPONENTS
+  RISCVCodeGen
+  RISCVDesc
+  RISCVInfo
+  TargetParser
+  MC
+  )
+
+add_llvm_target_unittest(RISCVTests
+  MCInstrAnalysisTest.cpp
+  )
+
+set_property(TARGET RISCVTests PROPERTY FOLDER "Tests/UnitTests/TargetTests")
diff --git a/llvm/unittests/Target/RISCV/MCInstrAnalysisTest.cpp b/llvm/unittests/Target/RISCV/MCInstrAnalysisTest.cpp
new file mode 100644 (file)
index 0000000..2ef92c7
--- /dev/null
@@ -0,0 +1,174 @@
+//===- MCInstrAnalysisTest.cpp - RISCVMCInstrAnalysis unit tests ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/MC/MCInstrAnalysis.h"
+#include "MCTargetDesc/RISCVMCTargetDesc.h"
+#include "llvm/MC/MCInstBuilder.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/TargetSelect.h"
+
+#include "gtest/gtest.h"
+
+#include <memory>
+
+using namespace llvm;
+
+namespace {
+
+class InstrAnalysisTest : public testing::TestWithParam<const char *> {
+protected:
+  std::unique_ptr<const MCInstrInfo> Info;
+  std::unique_ptr<const MCInstrAnalysis> Analysis;
+
+  static void SetUpTestSuite() {
+    LLVMInitializeRISCVTargetInfo();
+    LLVMInitializeRISCVTarget();
+    LLVMInitializeRISCVTargetMC();
+  }
+
+  InstrAnalysisTest() {
+    std::string Error;
+    const Target *TheTarget =
+        TargetRegistry::lookupTarget(Triple::normalize(GetParam()), Error);
+    Info = std::unique_ptr<const MCInstrInfo>(TheTarget->createMCInstrInfo());
+    Analysis = std::unique_ptr<const MCInstrAnalysis>(
+        TheTarget->createMCInstrAnalysis(Info.get()));
+  }
+};
+
+} // namespace
+
+static MCInst jal(unsigned RD) {
+  return MCInstBuilder(RISCV::JAL).addReg(RD).addImm(16);
+}
+
+static MCInst jalr(unsigned RD, unsigned RS1 = RISCV::X10) {
+  return MCInstBuilder(RISCV::JALR).addReg(RD).addReg(RS1).addImm(16);
+}
+
+static MCInst cjr(unsigned RS1) {
+  return MCInstBuilder(RISCV::C_JR).addReg(RS1);
+}
+
+static MCInst cj() { return MCInstBuilder(RISCV::C_J).addImm(16); }
+static MCInst cjal() { return MCInstBuilder(RISCV::C_JAL).addImm(16); }
+
+static MCInst cjalr(unsigned RS1) {
+  return MCInstBuilder(RISCV::C_JALR).addReg(RS1);
+}
+
+static MCInst beq() {
+  return MCInstBuilder(RISCV::BEQ)
+      .addReg(RISCV::X0)
+      .addReg(RISCV::X1)
+      .addImm(32);
+}
+
+static MCInst cbeqz() {
+  return MCInstBuilder(RISCV::C_BEQZ).addReg(RISCV::X1).addImm(32);
+}
+
+TEST_P(InstrAnalysisTest, IsTerminator) {
+  EXPECT_TRUE(Analysis->isTerminator(beq()));
+  EXPECT_TRUE(Analysis->isTerminator(cbeqz()));
+  EXPECT_TRUE(Analysis->isTerminator(jal(RISCV::X0)));
+  EXPECT_FALSE(Analysis->isTerminator(jal(RISCV::X5)));
+  EXPECT_TRUE(Analysis->isTerminator(jalr(RISCV::X0)));
+  EXPECT_FALSE(Analysis->isTerminator(jalr(RISCV::X5)));
+  EXPECT_TRUE(Analysis->isTerminator(cj()));
+  EXPECT_FALSE(Analysis->isTerminator(cjal()));
+}
+
+TEST_P(InstrAnalysisTest, IsCall) {
+  EXPECT_FALSE(Analysis->isCall(beq()));
+  EXPECT_FALSE(Analysis->isCall(cbeqz()));
+  EXPECT_FALSE(Analysis->isCall(jal(RISCV::X0)));
+  EXPECT_TRUE(Analysis->isCall(jal(RISCV::X1)));
+  EXPECT_TRUE(Analysis->isCall(jalr(RISCV::X1, RISCV::X1)));
+  EXPECT_FALSE(Analysis->isCall(jalr(RISCV::X0, RISCV::X5)));
+  EXPECT_FALSE(Analysis->isCall(cj()));
+  EXPECT_FALSE(Analysis->isCall(cjr(RISCV::X5)));
+  EXPECT_TRUE(Analysis->isCall(cjal()));
+  EXPECT_TRUE(Analysis->isCall(cjalr(RISCV::X5)));
+}
+
+TEST_P(InstrAnalysisTest, IsReturn) {
+  EXPECT_FALSE(Analysis->isReturn(beq()));
+  EXPECT_FALSE(Analysis->isReturn(cbeqz()));
+  EXPECT_FALSE(Analysis->isReturn(jal(RISCV::X0)));
+  EXPECT_TRUE(Analysis->isReturn(jalr(RISCV::X0, RISCV::X1)));
+  EXPECT_FALSE(Analysis->isReturn(jalr(RISCV::X1, RISCV::X1)));
+  EXPECT_TRUE(Analysis->isReturn(jalr(RISCV::X0, RISCV::X5)));
+  EXPECT_FALSE(Analysis->isReturn(cj()));
+  EXPECT_TRUE(Analysis->isReturn(cjr(RISCV::X1)));
+  EXPECT_FALSE(Analysis->isReturn(cjr(RISCV::X2)));
+  EXPECT_TRUE(Analysis->isReturn(cjr(RISCV::X5)));
+  EXPECT_FALSE(Analysis->isReturn(cjal()));
+  EXPECT_FALSE(Analysis->isReturn(cjalr(RISCV::X1)));
+  EXPECT_FALSE(Analysis->isReturn(cjalr(RISCV::X5)));
+}
+
+TEST_P(InstrAnalysisTest, IsBranch) {
+  EXPECT_TRUE(Analysis->isBranch(beq()));
+  EXPECT_TRUE(Analysis->isBranch(cbeqz()));
+  EXPECT_TRUE(Analysis->isBranch(jal(RISCV::X0)));
+  EXPECT_FALSE(Analysis->isBranch(jal(RISCV::X1)));
+  EXPECT_FALSE(Analysis->isBranch(jal(RISCV::X5)));
+  EXPECT_TRUE(Analysis->isBranch(jalr(RISCV::X0)));
+  EXPECT_FALSE(Analysis->isBranch(jalr(RISCV::X1)));
+  EXPECT_FALSE(Analysis->isBranch(jalr(RISCV::X5)));
+  EXPECT_FALSE(Analysis->isBranch(jalr(RISCV::X0, RISCV::X1)));
+  EXPECT_FALSE(Analysis->isBranch(jalr(RISCV::X0, RISCV::X5)));
+  EXPECT_TRUE(Analysis->isBranch(cj()));
+  EXPECT_TRUE(Analysis->isBranch(cjr(RISCV::X2)));
+  EXPECT_FALSE(Analysis->isBranch(cjr(RISCV::X1)));
+  EXPECT_FALSE(Analysis->isBranch(cjr(RISCV::X5)));
+  EXPECT_FALSE(Analysis->isBranch(cjal()));
+  EXPECT_FALSE(Analysis->isBranch(cjalr(RISCV::X6)));
+  EXPECT_FALSE(Analysis->isBranch(cjalr(RISCV::X1)));
+  EXPECT_FALSE(Analysis->isBranch(cjalr(RISCV::X5)));
+}
+
+TEST_P(InstrAnalysisTest, IsUnconditionalBranch) {
+  EXPECT_FALSE(Analysis->isUnconditionalBranch(beq()));
+  EXPECT_FALSE(Analysis->isUnconditionalBranch(cbeqz()));
+  EXPECT_TRUE(Analysis->isUnconditionalBranch(jal(RISCV::X0)));
+  EXPECT_FALSE(Analysis->isUnconditionalBranch(jal(RISCV::X1)));
+  EXPECT_FALSE(Analysis->isUnconditionalBranch(jal(RISCV::X5)));
+  EXPECT_TRUE(Analysis->isUnconditionalBranch(jalr(RISCV::X0)));
+  EXPECT_FALSE(Analysis->isUnconditionalBranch(jalr(RISCV::X1)));
+  EXPECT_FALSE(Analysis->isUnconditionalBranch(jalr(RISCV::X5)));
+  EXPECT_FALSE(Analysis->isUnconditionalBranch(jalr(RISCV::X0, RISCV::X1)));
+  EXPECT_FALSE(Analysis->isUnconditionalBranch(jalr(RISCV::X0, RISCV::X5)));
+  EXPECT_TRUE(Analysis->isUnconditionalBranch(cj()));
+  EXPECT_TRUE(Analysis->isUnconditionalBranch(cjr(RISCV::X2)));
+  EXPECT_FALSE(Analysis->isUnconditionalBranch(cjr(RISCV::X1)));
+  EXPECT_FALSE(Analysis->isUnconditionalBranch(cjr(RISCV::X5)));
+  EXPECT_FALSE(Analysis->isUnconditionalBranch(cjal()));
+  EXPECT_FALSE(Analysis->isUnconditionalBranch(cjalr(RISCV::X6)));
+  EXPECT_FALSE(Analysis->isUnconditionalBranch(cjalr(RISCV::X1)));
+  EXPECT_FALSE(Analysis->isUnconditionalBranch(cjalr(RISCV::X5)));
+}
+
+TEST_P(InstrAnalysisTest, IsIndirectBranch) {
+  EXPECT_FALSE(Analysis->isIndirectBranch(beq()));
+  EXPECT_FALSE(Analysis->isIndirectBranch(cbeqz()));
+  EXPECT_FALSE(Analysis->isIndirectBranch(jal(RISCV::X0)));
+  EXPECT_FALSE(Analysis->isIndirectBranch(jal(RISCV::X1)));
+  EXPECT_TRUE(Analysis->isIndirectBranch(jalr(RISCV::X0)));
+  EXPECT_FALSE(Analysis->isIndirectBranch(jalr(RISCV::X1)));
+  EXPECT_FALSE(Analysis->isIndirectBranch(cj()));
+  EXPECT_TRUE(Analysis->isIndirectBranch(cjr(RISCV::X10)));
+  EXPECT_FALSE(Analysis->isIndirectBranch(cjr(RISCV::X1)));
+  EXPECT_FALSE(Analysis->isIndirectBranch(cjr(RISCV::X5)));
+  EXPECT_FALSE(Analysis->isIndirectBranch(cjal()));
+  EXPECT_FALSE(Analysis->isIndirectBranch(cjalr(RISCV::X5)));
+}
+
+INSTANTIATE_TEST_SUITE_P(RV32And64, InstrAnalysisTest,
+                         testing::Values("riscv32", "riscv64"));