[JIT] Set future's error to current exception as is when `--torch_jit_enable_rethrow_...
authorDon Jang <djang@fb.com>
Tue, 17 Aug 2021 00:30:26 +0000 (17:30 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 17 Aug 2021 00:32:13 +0000 (17:32 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63348

This change addresses singlaiiit's comment on D30241792 (https://github.com/pytorch/pytorch/commit/61b49c8e41a2faf7fd40278ca72616c5d92963cb), which makes the JIT interpreter's behavior consistent between `future` is set and not.

Test Plan: Enhanced `EnableRethrowCaughtExceptionTest.EnableRethrowCaughtExceptionTestRethrowsCaughtException` to cover the modified code path.

Reviewed By: singlaiiit

Differential Revision: D30347782

fbshipit-source-id: 79ce57283154ca4372e5341217d942398db21ac8

test/cpp/jit/test_interpreter.cpp
torch/csrc/jit/runtime/interpreter.cpp

index 2ba2fba..a241891 100644 (file)
@@ -265,6 +265,21 @@ graph(%0 : Tensor,
         "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1");
   }
   EXPECT_TRUE(exception_handled);
+
+  FLAGS_torch_jit_enable_rethrow_caught_exception = true;
+  c10::intrusive_ptr<Future> future = interp.runAsync(stack);
+  future->wait();
+  ASSERT_TRUE(future->completed());
+  ASSERT_TRUE(future->hasError());
+  try {
+    std::rethrow_exception(future->exception_ptr());
+  } catch (c10::Error& e) {
+    std::string exception_msg = e.what_without_backtrace();
+    EXPECT_STREQ(
+        exception_msg.c_str(),
+        "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1");
+  }
+
   FLAGS_torch_jit_enable_rethrow_caught_exception = original_flag_value;
 }
 
index a095e4a..be2019e 100644 (file)
@@ -720,7 +720,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
       }
       if (FLAGS_torch_jit_enable_rethrow_caught_exception) {
         if (future_) {
-          future_->setError(std::make_exception_ptr(e));
+          future_->setError(std::current_exception());
+          return false;
         }
         throw;
       }