[WebAssembly] support "return" and unreachable code in asm type checker
authorWouter van Oortmerssen <aardappel@gmail.com>
Mon, 1 Nov 2021 20:09:47 +0000 (13:09 -0700)
committerWouter van Oortmerssen <aardappel@gmail.com>
Mon, 1 Nov 2021 22:42:58 +0000 (15:42 -0700)
To support return (it not being supported well was the ground cause for
https://github.com/WebAssembly/wasi-sdk/issues/200) we also have to have
at least a basic notion of unreachable, which in this case just means to stop
type checking until there is an end_block (an incoming control flow edge).
This is conservative (may miss on some type checking opportunities) but is
simple and an improvement over what we had before.

Differential Revision: https://reviews.llvm.org/D112953

llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp
llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
llvm/test/MC/WebAssembly/basic-assembly.s

index 7789823..7d1e6c5 100644 (file)
@@ -1114,6 +1114,8 @@ public:
 
   void onEndOfFunction(SMLoc ErrorLoc) {
     TC.endOfFunction(ErrorLoc);
+    // Reset the type checker state.
+    TC.Clear();
 
     // Automatically output a .size directive, so it becomes optional for the
     // user.
index 9e4162a..a6b5d42 100644 (file)
@@ -74,6 +74,9 @@ bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) {
   // which are mostly not helpful.
   if (TypeErrorThisFunction)
     return true;
+  // If we're currently in unreachable code, we surpress errors as well.
+  if (Unreachable)
+    return true;
   TypeErrorThisFunction = true;
   dumpTypeStack("current stack: ");
   return Parser.Error(ErrorLoc, Msg);
@@ -170,17 +173,18 @@ bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst,
   return false;
 }
 
-void WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) {
+bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) {
   // Check the return types.
   for (auto RVT : llvm::reverse(ReturnTypes)) {
-    popType(ErrorLoc, RVT);
+    if (popType(ErrorLoc, RVT))
+      return true;
   }
   if (!Stack.empty()) {
-    typeError(ErrorLoc,
-              std::to_string(Stack.size()) + " superfluous return values");
+    return typeError(ErrorLoc, std::to_string(Stack.size()) +
+                                   " superfluous return values");
   }
-  // Reset the type checker state.
-  Clear();
+  Unreachable = true;
+  return false;
 }
 
 bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst) {
@@ -219,10 +223,17 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst) {
              Name == "else" || Name == "end_try") {
     if (checkEnd(ErrorLoc))
       return true;
+    if (Name == "end_block")
+      Unreachable = false;
+  } else if (Name == "return") {
+    if (endOfFunction(ErrorLoc))
+      return true;
   } else if (Name == "call_indirect" || Name == "return_call_indirect") {
     // Function value.
     if (popType(ErrorLoc, wasm::ValType::I32)) return true;
     if (checkSig(ErrorLoc, LastSig)) return true;
+    if (Name == "return_call_indirect" && endOfFunction(ErrorLoc))
+      return true;
   } else if (Name == "call" || Name == "return_call") {
     const MCSymbolRefExpr *SymRef;
     if (getSymRef(ErrorLoc, Inst, SymRef))
@@ -233,6 +244,8 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst) {
       return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
                                       " missing .functype");
     if (checkSig(ErrorLoc, *Sig)) return true;
+    if (Name == "return_call" && endOfFunction(ErrorLoc))
+      return true;
   } else if (Name == "catch") {
     const MCSymbolRefExpr *SymRef;
     if (getSymRef(ErrorLoc, Inst, SymRef))
@@ -248,6 +261,8 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst) {
   } else if (Name == "ref.null") {
     auto VT = static_cast<wasm::ValType>(Inst.getOperand(0).getImm());
     Stack.push_back(VT);
+  } else if (Name == "unreachable") {
+    Unreachable = true;
   } else {
     // The current instruction is a stack instruction which doesn't have
     // explicit operands that indicate push/pop types, so we get those from
index a15a69b..aa35213 100644 (file)
@@ -32,15 +32,9 @@ class WebAssemblyAsmTypeCheck final {
   SmallVector<wasm::ValType, 4> ReturnTypes;
   wasm::WasmSignature LastSig;
   bool TypeErrorThisFunction = false;
+  bool Unreachable = false;
   bool is64;
 
-  void Clear() {
-    Stack.clear();
-    LocalTypes.clear();
-    ReturnTypes.clear();
-    TypeErrorThisFunction = false;
-  }
-
   void dumpTypeStack(Twine Msg);
   bool typeError(SMLoc ErrorLoc, const Twine &Msg);
   bool popType(SMLoc ErrorLoc, Optional<wasm::ValType> EVT);
@@ -57,8 +51,16 @@ public:
   void funcDecl(const wasm::WasmSignature &Sig);
   void localDecl(const SmallVector<wasm::ValType, 4> &Locals);
   void setLastSig(const wasm::WasmSignature &Sig) { LastSig = Sig; }
-  void endOfFunction(SMLoc ErrorLoc);
+  bool endOfFunction(SMLoc ErrorLoc);
   bool typeCheck(SMLoc ErrorLoc, const MCInst &Inst);
+
+  void Clear() {
+    Stack.clear();
+    LocalTypes.clear();
+    ReturnTypes.clear();
+    TypeErrorThisFunction = false;
+    Unreachable = false;
+  }
 };
 
 } // end namespace llvm
index 4464082..b86172b 100644 (file)
@@ -1,9 +1,10 @@
-# RUN: llvm-mc -triple=wasm32-unknown-unknown -mattr=+reference-types,atomics,+simd128,+nontrapping-fptoint,+exception-handling < %s | FileCheck %s
+# RUN: llvm-mc -triple=wasm32-unknown-unknown -mattr=+tail-call,+reference-types,atomics,+simd128,+nontrapping-fptoint,+exception-handling < %s | FileCheck %s
 # Check that it converts to .o without errors, but don't check any output:
-# RUN: llvm-mc -triple=wasm32-unknown-unknown -filetype=obj -mattr=+reference-types,+atomics,+simd128,+nontrapping-fptoint,+exception-handling -o %t.o < %s
+# RUN: llvm-mc -triple=wasm32-unknown-unknown -filetype=obj -mattr=+tail-call,+reference-types,+atomics,+simd128,+nontrapping-fptoint,+exception-handling -o %t.o < %s
 
 .functype   something1 () -> ()
 .functype   something2 (i64) -> (i32, f64)
+.functype   something3 () -> (i32)
 .globaltype __stack_pointer, i32
 
 empty_func:
@@ -86,6 +87,17 @@ test0:
     else
     end_if
     drop
+    block       void
+    i32.const   2
+    return
+    end_block
+    block       void
+    return_call something3
+    end_block
+    block       void
+    i32.const   3
+    return_call_indirect () -> (i32)
+    end_block
     local.get   4
     local.get   5
     f32x4.add
@@ -215,6 +227,17 @@ empty_fref_table:
 # CHECK-NEXT:      else
 # CHECK-NEXT:      end_if
 # CHECK-NEXT:      drop
+# CHECK-NEXT:      block
+# CHECK-NEXT:      i32.const   2
+# CHECK-NEXT:      return
+# CHECK-NEXT:      end_block
+# CHECK-NEXT:      block
+# CHECK-NEXT:      return_call something3
+# CHECK-NEXT:      end_block
+# CHECK-NEXT:      block
+# CHECK-NEXT:      i32.const   3
+# CHECK-NEXT:      return_call_indirect __indirect_function_table, () -> (i32)
+# CHECK-NEXT:      end_block
 # CHECK-NEXT:      local.get   4
 # CHECK-NEXT:      local.get   5
 # CHECK-NEXT:      f32x4.add