[WebAssembly] Eliminate range checks on br_tables
authorThomas Lively <tlively@google.com>
Tue, 2 Jun 2020 20:14:27 +0000 (13:14 -0700)
committerThomas Lively <tlively@google.com>
Tue, 2 Jun 2020 20:14:27 +0000 (13:14 -0700)
Summary:
Jump tables for most targets cannot handle out of range indices by
themselves, so LLVM emits range checks to guard the jump
tables. WebAssembly, on the other hand, implements jump tables using
the br_table instruction, which takes a default branch target as an
operand, making the range checks redundant. This patch introduces a
new MachineFunction pass in the WebAssembly backend to find and
eliminate the redundant range checks.

Reviewers: aheejin, dschuff

Subscribers: mgorny, sbc100, jgravelle-google, hiraditya, sunfish, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D80863

llvm/lib/Target/WebAssembly/CMakeLists.txt
llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h
llvm/lib/Target/WebAssembly/WebAssembly.h
llvm/lib/Target/WebAssembly/WebAssemblyFixBrTableDefaults.cpp [new file with mode: 0644]
llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp
llvm/test/CodeGen/WebAssembly/cfg-stackify.ll
llvm/test/CodeGen/WebAssembly/indirectbr.ll
llvm/test/CodeGen/WebAssembly/stack-insts.ll
llvm/test/CodeGen/WebAssembly/switch-unreachable-default.ll [new file with mode: 0644]
llvm/test/CodeGen/WebAssembly/switch.ll

index 0474242..b730172 100644 (file)
@@ -24,6 +24,7 @@ add_llvm_target(WebAssemblyCodeGen
   WebAssemblyExceptionInfo.cpp
   WebAssemblyExplicitLocals.cpp
   WebAssemblyFastISel.cpp
+  WebAssemblyFixBrTableDefaults.cpp
   WebAssemblyFixIrreducibleControlFlow.cpp
   WebAssemblyFixFunctionBitcasts.cpp
   WebAssemblyFrameLowering.cpp
index 79988b0..971b50e 100644 (file)
@@ -459,6 +459,18 @@ inline bool isCallIndirect(unsigned Opc) {
   }
 }
 
+inline bool isBrTable(const MachineInstr &MI) {
+  switch (MI.getOpcode()) {
+  case WebAssembly::BR_TABLE_I32:
+  case WebAssembly::BR_TABLE_I32_S:
+  case WebAssembly::BR_TABLE_I64:
+  case WebAssembly::BR_TABLE_I64_S:
+    return true;
+  default:
+    return false;
+  }
+}
+
 inline bool isMarker(unsigned Opc) {
   switch (Opc) {
   case WebAssembly::BLOCK:
index f36ca7d..9ce02f7 100644 (file)
@@ -44,6 +44,7 @@ FunctionPass *createWebAssemblyOptimizeLiveIntervals();
 FunctionPass *createWebAssemblyMemIntrinsicResults();
 FunctionPass *createWebAssemblyRegStackify();
 FunctionPass *createWebAssemblyRegColoring();
+FunctionPass *createWebAssemblyFixBrTableDefaults();
 FunctionPass *createWebAssemblyFixIrreducibleControlFlow();
 FunctionPass *createWebAssemblyLateEHPrepare();
 FunctionPass *createWebAssemblyCFGSort();
@@ -68,6 +69,7 @@ void initializeWebAssemblyOptimizeLiveIntervalsPass(PassRegistry &);
 void initializeWebAssemblyMemIntrinsicResultsPass(PassRegistry &);
 void initializeWebAssemblyRegStackifyPass(PassRegistry &);
 void initializeWebAssemblyRegColoringPass(PassRegistry &);
+void initializeWebAssemblyFixBrTableDefaultsPass(PassRegistry &);
 void initializeWebAssemblyFixIrreducibleControlFlowPass(PassRegistry &);
 void initializeWebAssemblyLateEHPreparePass(PassRegistry &);
 void initializeWebAssemblyExceptionInfoPass(PassRegistry &);
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyFixBrTableDefaults.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyFixBrTableDefaults.cpp
new file mode 100644 (file)
index 0000000..0b081e8
--- /dev/null
@@ -0,0 +1,129 @@
+//=- WebAssemblyFixBrTableDefaults.cpp - Fix br_table default branch targets -//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file implements a pass that eliminates redundant range checks
+/// guarding br_table instructions. Since jump tables on most targets cannot
+/// handle out of range indices, LLVM emits these checks before most jump
+/// tables. But br_table takes a default branch target as an argument, so it
+/// does not need the range checks.
+///
+//===----------------------------------------------------------------------===//
+
+#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
+#include "WebAssembly.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/Pass.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "wasm-fix-br-table-defaults"
+
+namespace {
+
+class WebAssemblyFixBrTableDefaults final : public MachineFunctionPass {
+  StringRef getPassName() const override {
+    return "WebAssembly Fix br_table Defaults";
+  }
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+
+public:
+  static char ID; // Pass identification, replacement for typeid
+  WebAssemblyFixBrTableDefaults() : MachineFunctionPass(ID) {}
+};
+
+char WebAssemblyFixBrTableDefaults::ID = 0;
+
+// `MI` is a br_table instruction missing its default target argument. This
+// function finds and adds the default target argument and removes any redundant
+// range check preceding the br_table.
+MachineBasicBlock *fixBrTable(MachineInstr &MI, MachineBasicBlock *MBB,
+                              MachineFunction &MF) {
+  // Get the header block, which contains the redundant range check.
+  assert(MBB->pred_size() == 1 && "Expected a single guard predecessor");
+  auto *HeaderMBB = *MBB->pred_begin();
+
+  // Find the conditional jump to the default target. If it doesn't exist, the
+  // default target is unreachable anyway, so we can choose anything.
+  auto &JumpMII = --HeaderMBB->end();
+  while (JumpMII->getOpcode() != WebAssembly::BR_IF &&
+         JumpMII != HeaderMBB->begin()) {
+    --JumpMII;
+  }
+  if (JumpMII->getOpcode() == WebAssembly::BR_IF) {
+    // Install the default target and remove the jumps in the header.
+    auto *DefaultMBB = JumpMII->getOperand(0).getMBB();
+    assert(DefaultMBB != MBB && "Expected conditional jump to default target");
+    MI.addOperand(MF, MachineOperand::CreateMBB(DefaultMBB));
+    HeaderMBB->erase(JumpMII, HeaderMBB->end());
+  } else {
+    // Arbitrarily choose the first jump target as the default.
+    auto *SomeMBB = MI.getOperand(1).getMBB();
+    MI.addOperand(MachineOperand::CreateMBB(SomeMBB));
+  }
+
+  // Splice the jump table into the header.
+  HeaderMBB->splice(HeaderMBB->end(), MBB, MBB->begin(), MBB->end());
+
+  // Update CFG to skip the old jump table block. Remove shared successors
+  // before transferring to avoid duplicated successors.
+  HeaderMBB->removeSuccessor(MBB);
+  for (auto &Succ : MBB->successors())
+    if (HeaderMBB->isSuccessor(Succ))
+      HeaderMBB->removeSuccessor(Succ);
+  HeaderMBB->transferSuccessorsAndUpdatePHIs(MBB);
+
+  // Remove the old jump table block from the function
+  MF.erase(MBB);
+
+  return HeaderMBB;
+}
+
+bool WebAssemblyFixBrTableDefaults::runOnMachineFunction(MachineFunction &MF) {
+  LLVM_DEBUG(dbgs() << "********** Fixing br_table Default Targets **********\n"
+                       "********** Function: "
+                    << MF.getName() << '\n');
+
+  bool Changed = false;
+  SmallPtrSet<MachineBasicBlock *, 16> MBBSet;
+  for (auto &MBB : MF)
+    MBBSet.insert(&MBB);
+
+  while (!MBBSet.empty()) {
+    MachineBasicBlock *MBB = *MBBSet.begin();
+    MBBSet.erase(MBB);
+    for (auto &MI : *MBB) {
+      if (WebAssembly::isBrTable(MI)) {
+        auto *Fixed = fixBrTable(MI, MBB, MF);
+        MBBSet.erase(Fixed);
+        Changed = true;
+        break;
+      }
+    }
+  }
+
+  if (Changed) {
+    // We rewrote part of the function; recompute relevant things.
+    MF.RenumberBlocks();
+    return true;
+  }
+
+  return false;
+}
+
+} // end anonymous namespace
+
+INITIALIZE_PASS(WebAssemblyFixBrTableDefaults, DEBUG_TYPE,
+                "Removes range checks and sets br_table default targets", false,
+                false);
+
+FunctionPass *llvm::createWebAssemblyFixBrTableDefaults() {
+  return new WebAssemblyFixBrTableDefaults();
+}
index 1b85ecd..83e4ed7 100644 (file)
@@ -1279,11 +1279,8 @@ SDValue WebAssemblyTargetLowering::LowerBR_JT(SDValue Op,
   for (auto MBB : MBBs)
     Ops.push_back(DAG.getBasicBlock(MBB));
 
-  // TODO: For now, we just pick something arbitrary for a default case for now.
-  // We really want to sniff out the guard and put in the real default case (and
-  // delete the guard).
-  Ops.push_back(DAG.getBasicBlock(MBBs[0]));
-
+  // Do not add the default case for now. It will be added in
+  // WebAssemblyFixBrTableDefaults.
   return DAG.getNode(WebAssemblyISD::BR_TABLE, DL, MVT::Other, Ops);
 }
 
index ca09842..7bf655c 100644 (file)
@@ -406,6 +406,10 @@ bool WebAssemblyPassConfig::addInstSelector() {
   // it's inconvenient to collect. Collect it now, and update the immediate
   // operands.
   addPass(createWebAssemblySetP2AlignOperands());
+
+  // Eliminate range checks and add default targets to br_table instructions.
+  addPass(createWebAssemblyFixBrTableDefaults());
+
   return false;
 }
 
index 8c16a46..ff04ddc 100644 (file)
@@ -382,14 +382,14 @@ if.end:
 
 ; CHECK-LABEL: test4:
 ; CHECK-NEXT: .functype test4 (i32) -> (){{$}}
-; CHECK:      block   {{$}}
 ; CHECK-NEXT: block   {{$}}
-; CHECK:      br_if     0, $pop{{[0-9]+}}{{$}}
-; CHECK:      br        1{{$}}
-; CHECK-NEXT: .LBB{{[0-9]+}}_2:
+; CHECK-NEXT: block   {{$}}
+; CHECK-NEXT: br_table   $0, 1, 1, 1, 1, 1, 0{{$}}
+; CHECK-NEXT: .LBB{{[0-9]+}}_1:
 ; CHECK-NEXT: end_block{{$}}
-; CHECK-NEXT: br_table   $0, 0, 0, 0, 0, 0, 0{{$}}
-; CHECK-NEXT: .LBB{{[0-9]+}}_3:
+; CHECK-NEXT: i32.const $push[[C:[0-9]+]]=, 622{{$}}
+; CHECK-NEXT: i32.eq $drop=, $0, $pop[[C]]{{$}}
+; CHECK-NEXT: .LBB{{[0-9]+}}_2:
 ; CHECK-NEXT: end_block{{$}}
 ; CHECK-NEXT: return{{$}}
 define void @test4(i32 %t) {
@@ -649,20 +649,16 @@ end:
 ; CHECK:       br_if     0, {{[^,]+}}{{$}}
 ; CHECK-NEXT:  end_loop{{$}}
 ; CHECK-NEXT:  block   {{$}}
-; CHECK:       br_if     0, {{[^,]+}}{{$}}
-; CHECK:       br        3{{$}}
-; CHECK-NEXT:  .LBB{{[0-9]+}}_7:
-; CHECK-NEXT:  end_block{{$}}
-; CHECK:       block   {{$}}
-; CHECK-NEXT:  br_table   $0, 0, 3, 1, 2, 0
-; CHECK-NEXT:  .LBB{{[0-9]+}}_8:
+; CHECK-NOT:   br_if
+; CHECK:       br_table   $pop{{[^,]+}}, 0, 3, 1, 2, 3
+; CHECK-NEXT:  .LBB{{[0-9]+}}_6:
 ; CHECK-NEXT:  end_block{{$}}
 ; CHECK-NEXT:  end_loop{{$}}
 ; CHECK-NEXT:  return{{$}}
-; CHECK-NEXT:  .LBB{{[0-9]+}}_9:
+; CHECK-NEXT:  .LBB{{[0-9]+}}_7:
 ; CHECK-NEXT:  end_block{{$}}
 ; CHECK:       br        0{{$}}
-; CHECK-NEXT:  .LBB{{[0-9]+}}_10:
+; CHECK-NEXT:  .LBB{{[0-9]+}}_8:
 ; CHECK-NEXT:  end_loop{{$}}
 define void @test10() {
 bb0:
@@ -767,25 +763,22 @@ bb8:
 
 ; CHECK-LABEL: test12:
 ; CHECK:       .LBB{{[0-9]+}}_1:
-; CHECK-NEXT:  loop    {{$}}
 ; CHECK-NEXT:  block   {{$}}
+; CHECK-NEXT:  loop    {{$}}
 ; CHECK-NEXT:  block   {{$}}
 ; CHECK-NEXT:  block   {{$}}
+; CHECK:       br_table  {{[^,]+}}, 1, 3, 3, 3, 1, 0{{$}}
+; CHECK-NEXT:  .LBB{{[0-9]+}}_2:
+; CHECK-NEXT:  end_block{{$}}
 ; CHECK:       br_if     0, {{[^,]+}}{{$}}
 ; CHECK:       br_if     2, {{[^,]+}}{{$}}
-; CHECK:       br_if     1, {{[^,]+}}{{$}}
-; CHECK-NEXT:  br        2{{$}}
 ; CHECK-NEXT:  .LBB{{[0-9]+}}_4:
 ; CHECK-NEXT:  end_block{{$}}
-; CHECK-NEXT:  br_table   $2, 1, 0, 0, 0, 1, 1{{$}}
+; CHECK:       br        0{{$}}
 ; CHECK-NEXT:  .LBB{{[0-9]+}}_5:
+; CHECK-NEXT:  end_loop{{$}}
 ; CHECK-NEXT:  end_block{{$}}
 ; CHECK-NEXT:  return{{$}}
-; CHECK-NEXT:  .LBB{{[0-9]+}}_6:
-; CHECK-NEXT:  end_block{{$}}
-; CHECK:       br        0{{$}}
-; CHECK-NEXT:  .LBB{{[0-9]+}}_7:
-; CHECK-NEXT:  end_loop{{$}}
 define void @test12(i8* %arg) {
 bb:
   br label %bb1
index d32f941..da73761 100644 (file)
@@ -13,20 +13,36 @@ target triple = "wasm32"
 
 ; Just check the barest skeleton of the structure
 ; CHECK-LABEL: test1:
+; CHECK: block
+; CHECK: block
+; CHECK: block
+; CHECK: block
 ; CHECK: i32.load
-; CHECK: i32.load $[[DEST:.+]]=
+; CHECK: i32.load
+; CHECK: i32.const
+; CHECK: i32.add $push[[DEST:.+]]=
+; CHECK: br_table $pop[[DEST]]
+; CHECK: end_block
+; CHECK: end_block
+; CHECK: end_block
+; CHECK: end_block
 ; CHECK: loop
 ; CHECK: block
 ; CHECK: block
+; CHECK: block
+; CHECK: block
+; CHECK: br_table ${{[^,]+}}, 0, 1, 2, 2
+; CHECK: end_block
 ; CHECK: end_block
+; CHECK: end_block
+; CHECK: block
 ; CHECK: block
 ; CHECK: block
-; CHECK: br_table $[[DEST]]
+; CHECK: br_table ${{[^,]+}}, 1, 2, 0
+; CHECK: end_block
 ; CHECK: end_block
 ; CHECK: end_block
-; CHECK: i32.load $[[DEST]]=
 ; CHECK: end_loop
-
 ; CHECK: test1.targets:
 ; CHECK-NEXT: .int32
 ; CHECK-NEXT: .int32
index c4ccddd..506f43c 100644 (file)
@@ -8,7 +8,7 @@ declare void @foo1()
 
 ; Tests if br_table is printed correctly with a tab.
 ; CHECK-LABEL: test0:
-; CHECK: br_table {0, 1, 0, 1, 0}
+; CHECK: br_table {0, 1, 0, 1, 2}
 define void @test0(i32 %n) {
 entry:
   switch i32 %n, label %sw.epilog [
diff --git a/llvm/test/CodeGen/WebAssembly/switch-unreachable-default.ll b/llvm/test/CodeGen/WebAssembly/switch-unreachable-default.ll
new file mode 100644 (file)
index 0000000..f862a45
--- /dev/null
@@ -0,0 +1,38 @@
+; RUN: llc < %s -asm-verbose=false -verify-machineinstrs | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128"
+target triple = "wasm32-unknown-unknown"
+
+; Test that switches are lowered correctly in the presence of an
+; unreachable default branch target.
+
+; CHECK-LABEL: foo:
+; CHECK-NEXT:    .functype foo (i32) -> (i32)
+; CHECK-NEXT:    block
+; CHECK-NEXT:    block
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    br_table {0, 1, 0}
+; CHECK-NEXT:  .LBB0_1:
+; CHECK-NEXT:    end_block
+; CHECK-NEXT:    i32.const 0
+; CHECK-NEXT:    return
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    end_block
+; CHECK-NEXT:    i32.const 1
+; CHECK-NEXT:    end_function
+define i32 @foo(i32 %x) {
+entry:
+  switch i32 %x, label %unreachable [
+    i32 0, label %bb0
+    i32 1, label %bb1
+  ]
+
+bb0:
+  ret i32 0
+
+bb1:
+  ret i32 1
+
+unreachable:
+  unreachable
+}
index 1b0dfc8..3a9da70 100644 (file)
@@ -21,20 +21,20 @@ declare void @foo5()
 ; CHECK: block   {{$}}
 ; CHECK: block   {{$}}
 ; CHECK: block   {{$}}
-; CHECK: br_table {{[^,]+}}, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 4, 5, 0{{$}}
-; CHECK: .LBB{{[0-9]+}}_2:
+; CHECK: br_table {{[^,]+}}, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 4, 5, 6{{$}}
+; CHECK: .LBB{{[0-9]+}}_1:
 ; CHECK:   call foo0{{$}}
-; CHECK: .LBB{{[0-9]+}}_3:
+; CHECK: .LBB{{[0-9]+}}_2:
 ; CHECK:   call foo1{{$}}
-; CHECK: .LBB{{[0-9]+}}_4:
+; CHECK: .LBB{{[0-9]+}}_3:
 ; CHECK:   call foo2{{$}}
-; CHECK: .LBB{{[0-9]+}}_5:
+; CHECK: .LBB{{[0-9]+}}_4:
 ; CHECK:   call foo3{{$}}
-; CHECK: .LBB{{[0-9]+}}_6:
+; CHECK: .LBB{{[0-9]+}}_5:
 ; CHECK:   call foo4{{$}}
-; CHECK: .LBB{{[0-9]+}}_7:
+; CHECK: .LBB{{[0-9]+}}_6:
 ; CHECK:   call foo5{{$}}
-; CHECK: .LBB{{[0-9]+}}_8:
+; CHECK: .LBB{{[0-9]+}}_7:
 ; CHECK:   return{{$}}
 define void @bar32(i32 %n) {
 entry:
@@ -101,20 +101,20 @@ sw.epilog:                                        ; preds = %entry, %sw.bb.5, %s
 ; CHECK: block   {{$}}
 ; CHECK: block   {{$}}
 ; CHECK: block   {{$}}
-; CHECK: br_table {{[^,]+}}, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 4, 5, 0{{$}}
-; CHECK: .LBB{{[0-9]+}}_2:
+; CHECK: br_table {{[^,]+}}, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 4, 5, 6{{$}}
+; CHECK: .LBB{{[0-9]+}}_1:
 ; CHECK:   call foo0{{$}}
-; CHECK: .LBB{{[0-9]+}}_3:
+; CHECK: .LBB{{[0-9]+}}_2:
 ; CHECK:   call foo1{{$}}
-; CHECK: .LBB{{[0-9]+}}_4:
+; CHECK: .LBB{{[0-9]+}}_3:
 ; CHECK:   call foo2{{$}}
-; CHECK: .LBB{{[0-9]+}}_5:
+; CHECK: .LBB{{[0-9]+}}_4:
 ; CHECK:   call foo3{{$}}
-; CHECK: .LBB{{[0-9]+}}_6:
+; CHECK: .LBB{{[0-9]+}}_5:
 ; CHECK:   call foo4{{$}}
-; CHECK: .LBB{{[0-9]+}}_7:
+; CHECK: .LBB{{[0-9]+}}_6:
 ; CHECK:   call foo5{{$}}
-; CHECK: .LBB{{[0-9]+}}_8:
+; CHECK: .LBB{{[0-9]+}}_7:
 ; CHECK:   return{{$}}
 define void @bar64(i64 %n) {
 entry: