[TSAN] Handle musttail call properly in EscapeEnumerator (and TSAN)
authorXun Li <xun@fb.com>
Tue, 15 Sep 2020 22:19:57 +0000 (15:19 -0700)
committerXun Li <xun@fb.com>
Tue, 15 Sep 2020 22:20:05 +0000 (15:20 -0700)
Call instructions with musttail tag must be optimized as a tailcall, otherwise could lead to incorrect program behavior.
When TSAN is instrumenting functions, it broke the contract by adding a call to the tsan exit function inbetween the musttail call and return instruction, and also inserted exception handling code.
This happend throguh EscapeEnumerator, which adds exception handling code and returns ret instructions as the place to insert instrumentation calls.
This becomes especially problematic for coroutines, because coroutines rely on tail calls to do symmetric transfers properly.
To fix this, this patch moves the location to insert instrumentation calls prior to the musttail call for ret instructions that are following musttail calls, and also does not handle exception for musttail calls.

Differential Revision: https://reviews.llvm.org/D87620

llvm/lib/Transforms/Utils/EscapeEnumerator.cpp
llvm/test/Instrumentation/ThreadSanitizer/tsan_musttail.ll [new file with mode: 0644]

index cae9d9e..dca58bc 100644 (file)
@@ -41,7 +41,27 @@ IRBuilder<> *EscapeEnumerator::Next() {
     if (!isa<ReturnInst>(TI) && !isa<ResumeInst>(TI))
       continue;
 
-    Builder.SetInsertPoint(TI);
+    // If the ret instruction is followed by a musttaill call,
+    // or a bitcast instruction and then a musttail call, we should return
+    // the musttail call as the insertion point to not break the musttail
+    // contract.
+    auto AdjustMustTailCall = [&](Instruction *I) -> Instruction * {
+      auto *RI = dyn_cast<ReturnInst>(I);
+      if (!RI || !RI->getPrevNode())
+        return I;
+      auto *CI = dyn_cast<CallInst>(RI->getPrevNode());
+      if (CI && CI->isMustTailCall())
+        return CI;
+      auto *BI = dyn_cast<BitCastInst>(RI->getPrevNode());
+      if (!BI || !BI->getPrevNode())
+        return I;
+      CI = dyn_cast<CallInst>(BI->getPrevNode());
+      if (CI && CI->isMustTailCall())
+        return CI;
+      return I;
+    };
+
+    Builder.SetInsertPoint(AdjustMustTailCall(TI));
     return &Builder;
   }
 
@@ -54,11 +74,12 @@ IRBuilder<> *EscapeEnumerator::Next() {
     return nullptr;
 
   // Find all 'call' instructions that may throw.
+  // We cannot tranform calls with musttail tag.
   SmallVector<Instruction *, 16> Calls;
   for (BasicBlock &BB : F)
     for (Instruction &II : BB)
       if (CallInst *CI = dyn_cast<CallInst>(&II))
-        if (!CI->doesNotThrow())
+        if (!CI->doesNotThrow() && !CI->isMustTailCall())
           Calls.push_back(CI);
 
   if (Calls.empty())
diff --git a/llvm/test/Instrumentation/ThreadSanitizer/tsan_musttail.ll b/llvm/test/Instrumentation/ThreadSanitizer/tsan_musttail.ll
new file mode 100644 (file)
index 0000000..bb681f6
--- /dev/null
@@ -0,0 +1,30 @@
+; To test that __tsan_func_exit always happen before musttaill call and no exception handling code.
+; RUN: opt < %s -tsan -S | FileCheck %s
+
+define internal i32 @preallocated_musttail(i32* preallocated(i32) %p) sanitize_thread {
+  %rv = load i32, i32* %p
+  ret i32 %rv
+}
+
+define i32 @call_preallocated_musttail(i32* preallocated(i32) %a) sanitize_thread {
+  %r = musttail call i32 @preallocated_musttail(i32* preallocated(i32) %a)
+  ret i32 %r
+}
+
+; CHECK-LABEL:  define i32 @call_preallocated_musttail(i32* preallocated(i32) %a) 
+; CHECK:          call void @__tsan_func_exit()
+; CHECK-NEXT:     %r = musttail call i32 @preallocated_musttail(i32* preallocated(i32) %a)
+; CHECK-NEXT:     ret i32 %r
+
+
+define i32 @call_preallocated_musttail_cast(i32* preallocated(i32) %a) sanitize_thread {
+  %r = musttail call i32 @preallocated_musttail(i32* preallocated(i32) %a)
+  %t = bitcast i32 %r to i32
+  ret i32 %t
+}
+
+; CHECK-LABEL:  define i32 @call_preallocated_musttail_cast(i32* preallocated(i32) %a)
+; CHECK:          call void @__tsan_func_exit()
+; CHECK-NEXT:     %r = musttail call i32 @preallocated_musttail(i32* preallocated(i32) %a)
+; CHECK-NEXT:     %t = bitcast i32 %r to i32
+; CHECK-NEXT:     ret i32 %t