Support error handling in forked threads (#14523)
author James Sun <jamessun@fb.com>
Thu, 20 Dec 2018 02:51:41 +0000 (18:51 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 20 Dec 2018 02:54:46 +0000 (18:54 -0800)
Summary:
When a forked thread fails, save the error info in its future so that the
parent thread can pick it up. When the failing thread is the root thread,
throw the error directly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14523

Differential Revision: D13251756

Pulled By: highker

fbshipit-source-id: b40f9a45665e1a934743f131ec5e8bad5622ce67

aten/src/ATen/core/ivalue.h
test/test_jit.py
torch/csrc/jit/interpreter.cpp

index 8b99ecc..6626c6c 100644 (file)
@@ -512,6 +512,19 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
   }
 
  public:
+  struct CAFFE2_API FutureError final : public std::exception {
+    FutureError(std::string&& error_msg_)
+        : error_msg(std::move(error_msg_)) {}
+
+    FutureError() = default;
+
+    const char* what() const noexcept override {
+      return error_msg.c_str();
+    }
+
+    std::string error_msg;
+  };
+
   /**
   * Wait on the future until it completes.
   */
@@ -552,18 +565,30 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
       value_ = std::move(value);
     }
 
-    // There is no need to protect callbacks anymore.
-    // Once completed_ is set to true, no one can add new callback to the list.
-    for (auto& callback : callbacks) {
-      callback();
+    fireCallbacks();
+  }
+
+  void markCompleted(FutureError&& error_) {
+    {
+      // This is not to protect completed_ but to create a barrier
+      // from possible addCallback() calls
+      std::unique_lock<std::mutex> lock(mutex_);
+      AT_ASSERT(!completed());
+      completed_ = true;
+      has_error = true;
+      error = std::move(error_);
     }
-    callbacks.clear();
+
+    fireCallbacks();
   }
 
   // Get the result of the current future.
   IValue value() {
     std::unique_lock<std::mutex> lock(mutex_);
     AT_ASSERT(completed());
+    if (has_error) {
+      throw error;
+    }
     return value_;
   }
 
@@ -593,10 +618,22 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
       const Future& v);
 
  private:
+  void fireCallbacks() {
+    AT_ASSERT(completed());
+    // There is no need to protect callbacks with the lock.
+    // Once completed_ is set to true, no one can add new callback to the list.
+    for (auto& callback : callbacks) {
+      callback();
+    }
+    callbacks.clear();
+  }
+
   std::mutex mutex_;
   IValue value_; // when finished the value
   std::atomic_bool completed_ = {false}; // is this future complete
   std::vector<std::function<void(void)>> callbacks;
+  bool has_error = false;
+  FutureError error;
 };
 
 #undef TORCH_FORALL_TAGS
index 5828801..428290b 100644 (file)
@@ -11329,6 +11329,37 @@ class TestAsync(JitTestCase):
         y = torch.neg(x)
         self.assertEqual(module(x), tuple([y, y, y, y, x, x]))
 
+    def test_async_script_error(self):
+        x = torch.rand(3, 4)
+
+        @torch.jit.script
+        def foo(x):
+            # error here
+            return x.t() + x
+
+        @torch.jit.script
+        def wait_script(x):
+            fut = torch.jit._fork(foo, x)
+            return torch.jit._wait(fut)
+
+        @torch.jit.script
+        def wait_script_nest(x):
+            fut = torch.jit._fork(wait_script, x)
+            return torch.jit._wait(fut)
+
+        # no future
+        error_msg = 'The size.*must match the size of tensor'
+        with self.assertRaisesRegex(Exception, error_msg):
+            foo(x)
+
+        # one future
+        with self.assertRaisesRegex(Exception, error_msg):
+            wait_script(x)
+
+        # two futures with a different error
+        x = torch.rand(3, 4, 5)
+        with self.assertRaisesRegex(Exception, 'expects a 2D tensor'):
+            wait_script_nest(x)
 
 for test in autograd_method_tests:
     add_autograd_test(*test)
index a8a44d8..2c45c19 100644 (file)
@@ -692,16 +692,21 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
           });
 
           return true;
-        } catch(std::exception & e) {
-          if (!instructions[pc].debug_location) {
-            throw;
-          }
-          auto msg = instructions[pc].debug_location->wrapException(e, "operation failed in interpreter");
-          if (dynamic_cast<JITException *>(&e)) {
-            throw JITException(msg);
+        } catch (Future::FutureError& e) {
+          // Error from the forked thread.
+          auto msg = e.error_msg; // copy the error for each callback
+          handleError(std::move(msg), false);
+          return false;
+        } catch (std::exception& e) {
+          // Error from the current thread
+          bool is_jit_exception = dynamic_cast<JITException*>(&e);
+          if (instructions[pc].debug_location) {
+            handleError(instructions[pc].debug_location->wrapException(
+                e, "operation failed in interpreter"), is_jit_exception);
           } else {
-            throw std::runtime_error(msg);
+            handleError(e.what(), is_jit_exception);
           }
+          return false;
         }
     }
     if (future) {
@@ -717,6 +722,16 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
     return false;
   }
 
+  void handleError(std::string&& error_msg, bool is_jit_exception) {
+    if (future) {
+      future->markCompleted(Future::FutureError(std::move(error_msg)));
+    } else if (is_jit_exception) {
+      throw JITException(std::move(error_msg));
+    } else {
+      throw std::runtime_error(std::move(error_msg));
+    }
+  }
+
  public:
   c10::intrusive_ptr<Future> getOrCreateFuture() {
     if (!future) {