}
public:
+ // Exception type used to carry an error from the thread that completes a
+ // Future over to the thread that waits on it (rethrown from value()).
+ struct CAFFE2_API FutureError final : public std::exception {
+ // Takes ownership of the message by move; no copy is made.
+ FutureError(std::string&& error_msg_)
+ : error_msg(std::move(error_msg_)) {}
+
+ // Default-constructed error has an empty message (what() returns "").
+ FutureError() = default;
+
+ const char* what() const noexcept override {
+ return error_msg.c_str();
+ }
+
+ // Public so catch sites (e.g. the interpreter) can copy the message out
+ // directly instead of going through what().
+ std::string error_msg;
+ };
+
/**
* Wait on the future until it completes.
*/
value_ = std::move(value);
}
- // There is no need to protect callbacks anymore.
- // Once completed_ is set to true, no one can add new callback to the list.
- for (auto& callback : callbacks) {
- callback();
+ fireCallbacks();
+ }
+
+ // Complete this future with an error instead of a value. The stored
+ // FutureError is rethrown to any thread that later calls value().
+ void markCompleted(FutureError&& error_) {
+ {
+ // This is not to protect completed_ but to create a barrier
+ // from possible addCallback() calls
+ std::unique_lock<std::mutex> lock(mutex_);
+ AT_ASSERT(!completed());
+ completed_ = true;
+ has_error = true;
+ error = std::move(error_);
+ }
- callbacks.clear();
+
+ // Callbacks run outside the lock; once completed_ is true no new
+ // callbacks can be added, so this is safe (see fireCallbacks()).
+ fireCallbacks();
}
// Get the result of the current future.
IValue value() {
std::unique_lock<std::mutex> lock(mutex_);
AT_ASSERT(completed());
+ // If the producer completed this future with an error, rethrow it on
+ // the waiting thread rather than returning a (meaningless) value_.
+ if (has_error) {
+ throw error;
+ }
return value_;
}
const Future& v);
private:
+ // Invoke and then discard every registered callback. Must only be called
+ // after completed_ has been set (asserted below); both markCompleted
+ // overloads call this after their locked section.
+ void fireCallbacks() {
+ AT_ASSERT(completed());
+ // There is no need to protect callbacks with the lock.
+ // Once completed_ is set to true, no one can add new callback to the list.
+ for (auto& callback : callbacks) {
+ callback();
+ }
+ callbacks.clear();
+ }
+
std::mutex mutex_;
IValue value_; // when finished the value
std::atomic_bool completed_ = {false}; // is this future complete
std::vector<std::function<void(void)>> callbacks;
+ bool has_error = false;
+ FutureError error;
};
#undef TORCH_FORALL_TAGS
y = torch.neg(x)
self.assertEqual(module(x), tuple([y, y, y, y, x, x]))
+ def test_async_script_error(self):
+ # An error raised inside a forked TorchScript function should
+ # propagate through torch.jit._wait to the caller, including across
+ # nested fork/wait levels.
+ x = torch.rand(3, 4)
+
+ @torch.jit.script
+ def foo(x):
+ # error here: x.t() has shape (4, 3) which cannot be added to the
+ # (3, 4) input, so this fails at runtime with a size-mismatch error
+ return x.t() + x
+
+ @torch.jit.script
+ def wait_script(x):
+ fut = torch.jit._fork(foo, x)
+ return torch.jit._wait(fut)
+
+ @torch.jit.script
+ def wait_script_nest(x):
+ fut = torch.jit._fork(wait_script, x)
+ return torch.jit._wait(fut)
+
+ # no future: the error surfaces directly from the synchronous call
+ error_msg = 'The size.*must match the size of tensor'
+ with self.assertRaisesRegex(Exception, error_msg):
+ foo(x)
+
+ # one future: the same error must cross one fork/wait boundary
+ with self.assertRaisesRegex(Exception, error_msg):
+ wait_script(x)
+
+ # two futures with a different error: a 3D input makes t() itself
+ # fail, and that error must cross two fork/wait boundaries
+ x = torch.rand(3, 4, 5)
+ with self.assertRaisesRegex(Exception, 'expects a 2D tensor'):
+ wait_script_nest(x)
for test in autograd_method_tests:
add_autograd_test(*test)
});
return true;
- } catch(std::exception & e) {
- if (!instructions[pc].debug_location) {
- throw;
- }
- auto msg = instructions[pc].debug_location->wrapException(e, "operation failed in interpreter");
- if (dynamic_cast<JITException *>(&e)) {
- throw JITException(msg);
+ } catch (Future::FutureError& e) {
+ // Error from the forked thread.
+ auto msg = e.error_msg; // copy the error for each callback
+ handleError(std::move(msg), false);
+ return false;
+ } catch (std::exception& e) {
+ // Error from the current thread
+ bool is_jit_exception = dynamic_cast<JITException*>(&e);
+ if (instructions[pc].debug_location) {
+ handleError(instructions[pc].debug_location->wrapException(
+ e, "operation failed in interpreter"), is_jit_exception);
} else {
- throw std::runtime_error(msg);
+ handleError(e.what(), is_jit_exception);
}
+ return false;
}
}
if (future) {
return false;
}
+ // Route an interpreter error: if this frame has a future (async/forked
+ // execution), complete the future with the error so the waiter rethrows
+ // it from value(); otherwise rethrow synchronously, preserving
+ // JITException-ness so callers that catch that type still work.
+ void handleError(std::string&& error_msg, bool is_jit_exception) {
+ if (future) {
+ future->markCompleted(Future::FutureError(std::move(error_msg)));
+ } else if (is_jit_exception) {
+ throw JITException(std::move(error_msg));
+ } else {
+ // NOTE(review): std::runtime_error's ctor takes const std::string&,
+ // so this std::move is a no-op (the message is copied) — harmless.
+ throw std::runtime_error(std::move(error_msg));
+ }
+ }
+
public:
c10::intrusive_ptr<Future> getOrCreateFuture() {
if (!future) {