From: James Sun Date: Thu, 20 Dec 2018 02:51:41 +0000 (-0800) Subject: Support error handling in forked threads (#14523) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2160 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=88bf683cbc1ff2f2655a00fc0ec1774844a704aa;p=platform%2Fupstream%2Fpytorch.git Support error handling in forked threads (#14523) Summary: Save error info in the future for parent thread to pick up. Throw the error when the thread is the root thread. Pull Request resolved: https://github.com/pytorch/pytorch/pull/14523 Differential Revision: D13251756 Pulled By: highker fbshipit-source-id: b40f9a45665e1a934743f131ec5e8bad5622ce67 --- diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 8b99ecc..6626c6c 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -512,6 +512,19 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { } public: + struct CAFFE2_API FutureError final : public std::exception { + FutureError(std::string&& error_msg_) + : error_msg(std::move(error_msg_)) {} + + FutureError() = default; + + const char* what() const noexcept override { + return error_msg.c_str(); + } + + std::string error_msg; + }; + /** * Wait on the future until it completes. */ @@ -552,18 +565,30 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { value_ = std::move(value); } - // There is no need to protect callbacks anymore. - // Once completed_ is set to true, no one can add new callback to the list. - for (auto& callback : callbacks) { - callback(); + fireCallbacks(); + } + + void markCompleted(FutureError&& error_) { + { + // This is not to protect completed_ but to create a barrier + // from possible addCallback() calls + std::unique_lock lock(mutex_); + AT_ASSERT(!completed()); + completed_ = true; + has_error = true; + error = std::move(error_); } - callbacks.clear(); + + fireCallbacks(); } // Get the result of the current future. IValue value() { std::unique_lock lock(mutex_); AT_ASSERT(completed()); + if (has_error) { + throw error; + } return value_; } @@ -593,10 +618,22 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { const Future& v); private: + void fireCallbacks() { + AT_ASSERT(completed()); + // There is no need to protect callbacks with the lock. + // Once completed_ is set to true, no one can add new callback to the list. + for (auto& callback : callbacks) { + callback(); + } + callbacks.clear(); + } + std::mutex mutex_; IValue value_; // when finished the value std::atomic_bool completed_ = {false}; // is this future complete std::vector> callbacks; + bool has_error = false; + FutureError error; }; #undef TORCH_FORALL_TAGS diff --git a/test/test_jit.py b/test/test_jit.py index 5828801..428290b 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -11329,6 +11329,37 @@ class TestAsync(JitTestCase): y = torch.neg(x) self.assertEqual(module(x), tuple([y, y, y, y, x, x])) + def test_async_script_error(self): + x = torch.rand(3, 4) + + @torch.jit.script + def foo(x): + # error here + return x.t() + x + + @torch.jit.script + def wait_script(x): + fut = torch.jit._fork(foo, x) + return torch.jit._wait(fut) + + @torch.jit.script + def wait_script_nest(x): + fut = torch.jit._fork(wait_script, x) + return torch.jit._wait(fut) + + # no future + error_msg = 'The size.*must match the size of tensor' + with self.assertRaisesRegex(Exception, error_msg): + foo(x) + + # one future + with self.assertRaisesRegex(Exception, error_msg): + wait_script(x) + + # two futures with a different error + x = torch.rand(3, 4, 5) + with self.assertRaisesRegex(Exception, 'expects a 2D tensor'): + wait_script_nest(x) for test in autograd_method_tests: add_autograd_test(*test) diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index a8a44d8..2c45c19 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -692,16 +692,21 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { }); return true; - } catch(std::exception & e) { - if (!instructions[pc].debug_location) { - throw; - } - auto msg = instructions[pc].debug_location->wrapException(e, "operation failed in interpreter"); - if (dynamic_cast(&e)) { - throw JITException(msg); + } catch (Future::FutureError& e) { + // Error from the forked thread. + auto msg = e.error_msg; // copy the error for each callback + handleError(std::move(msg), false); + return false; + } catch (std::exception& e) { + // Error from the current thread + bool is_jit_exception = dynamic_cast(&e); + if (instructions[pc].debug_location) { + handleError(instructions[pc].debug_location->wrapException( + e, "operation failed in interpreter"), is_jit_exception); } else { - throw std::runtime_error(msg); + handleError(e.what(), is_jit_exception); } + return false; } } if (future) { @@ -717,6 +722,16 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { return false; } + void handleError(std::string&& error_msg, bool is_jit_exception) { + if (future) { + future->markCompleted(Future::FutureError(std::move(error_msg))); + } else if (is_jit_exception) { + throw JITException(std::move(error_msg)); + } else { + throw std::runtime_error(std::move(error_msg)); + } + } + public: c10::intrusive_ptr getOrCreateFuture() { if (!future) {