Allow uncompiled strings as input to `checkScriptRaisesRegex` (#63901)
authorAnsley Ussery <ansley@fb.com>
Thu, 26 Aug 2021 19:14:32 +0000 (12:14 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 19:17:07 +0000 (12:17 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63901

cc gmagogsfm

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D30579472

Pulled By: ansley

fbshipit-source-id: 59ee09c1f25278d4f6e51f626588251bd095c6ea

test/jit/test_jit_utils.py
torch/testing/_internal/jit_utils.py

index 11d974b..b344f82 100644 (file)
@@ -77,3 +77,18 @@ class TestJitUtils(JitTestCase):
         self.assertEqual(
             [],
             torch._jit_internal.get_callable_argument_names(fn_hybrid_args))
+
+    def test_checkscriptassertraisesregex(self):
+        def fn():
+            tup = (1, 2)
+            return tup[2]
+
+        self.checkScriptRaisesRegex(fn, (), Exception, "range", name="fn")
+
+        s = dedent("""
+        def fn():
+            tup = (1, 2)
+            return tup[2]
+        """)
+
+        self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn")
index 50d8dac..4c521a8 100644 (file)
@@ -375,35 +375,53 @@ class JitTestCase(JitCommonTestCase):
         return _AssertRaisesRegexWithHighlightContext(self, exception, regex, highlight)
 
     def checkScriptRaisesRegex(self, script, inputs, exception, regex,
-                               outputs=None, capture_output=False, profiling=ProfilingMode.PROFILING):
+                               name=None, outputs=None, capture_output=False,
+                               frames_up=1, profiling=ProfilingMode.PROFILING):
         """
         Checks that a given function will throw the correct exception,
-        when executed with normal python, the string frontend, and the AST frontend
+        when executed with normal python, the string frontend, and the
+        AST frontend. Logic taken from `checkScript` (see comments there
+        for details)
         """
-
         with enable_profiling_mode_for_profiling_tests():
-            # normal python
+            # Normal Python
             with self.assertRaisesRegex(exception, regex):
-                script(*inputs)
-            # string frontend
+                if isinstance(script, str):
+                    frame = self.get_frame_vars(frames_up)
+                    the_locals: Dict[str, Any] = {}
+                    execWrapper(script, glob=frame, loc=the_locals)
+                    frame.update(the_locals)
+
+                    python_fn = frame[name]
+                else:
+                    python_fn = script
+
+                python_fn(*inputs)
+
+            # String frontend
             with self.assertRaisesRegex(exception, regex):
-                source = textwrap.dedent(inspect.getsource(script))
-                cu = torch.jit.CompilationUnit(source)
-                ge = getattr(cu, script.__name__)
-                # profiling run
+                if isinstance(script, str):
+                    cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)
+                    string_frontend = getattr(cu, name)
+                else:
+                    source = textwrap.dedent(inspect.getsource(script))
+                    cu = torch.jit.CompilationUnit(source, _frames_up=frames_up)
+                    string_frontend = getattr(cu, script.__name__)
+
                 with self.assertRaisesRegex(exception, regex):
-                    ge(*inputs)
+                    string_frontend(*inputs)
                 # optimized run
-                ge(*inputs)
-            # python AST frontend
-            with self.assertRaisesRegex(exception, regex):
-                ge = torch.jit.script(script)
-                # profiling run
+                string_frontend(*inputs)
+
+            # Python AST frontend
+            if not isinstance(script, str):
                 with self.assertRaisesRegex(exception, regex):
+                    ge = torch.jit.script(python_fn)
+                    # profiling run
+                    with self.assertRaisesRegex(exception, regex):
+                        ge(*inputs)
+                    # optimized run
                     ge(*inputs)
-                # optimized run
-                ge(*inputs)
-
 
     def checkBailouts(self, model, inputs, expected):
         state = model.get_debug_state()