return _AssertRaisesRegexWithHighlightContext(self, exception, regex, highlight)
def checkScriptRaisesRegex(self, script, inputs, exception, regex,
- outputs=None, capture_output=False, profiling=ProfilingMode.PROFILING):
+ name=None, outputs=None, capture_output=False,
+ frames_up=1, profiling=ProfilingMode.PROFILING):
"""
Checks that a given function will throw the correct exception,
- when executed with normal python, the string frontend, and the AST frontend
+ when executed with normal python, the string frontend, and the
+ AST frontend. Logic taken from `checkScript` (see comments there
+ for details)
"""
-
with enable_profiling_mode_for_profiling_tests():
- # normal python
+ # Normal Python
with self.assertRaisesRegex(exception, regex):
- script(*inputs)
- # string frontend
+ if isinstance(script, str):
+ frame = self.get_frame_vars(frames_up)
+ the_locals: Dict[str, Any] = {}
+ execWrapper(script, glob=frame, loc=the_locals)
+ frame.update(the_locals)
+
+ python_fn = frame[name]
+ else:
+ python_fn = script
+
+ python_fn(*inputs)
+
+ # String frontend
with self.assertRaisesRegex(exception, regex):
- source = textwrap.dedent(inspect.getsource(script))
- cu = torch.jit.CompilationUnit(source)
- ge = getattr(cu, script.__name__)
- # profiling run
+ if isinstance(script, str):
+ cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)
+ string_frontend = getattr(cu, name)
+ else:
+ source = textwrap.dedent(inspect.getsource(script))
+ cu = torch.jit.CompilationUnit(source, _frames_up=frames_up)
+ string_frontend = getattr(cu, script.__name__)
+
with self.assertRaisesRegex(exception, regex):
- ge(*inputs)
+ string_frontend(*inputs)
# optimized run
- ge(*inputs)
- # python AST frontend
- with self.assertRaisesRegex(exception, regex):
- ge = torch.jit.script(script)
- # profiling run
+ string_frontend(*inputs)
+
+ # Python AST frontend
+ if not isinstance(script, str):
with self.assertRaisesRegex(exception, regex):
+ ge = torch.jit.script(python_fn)
+ # profiling run
+ with self.assertRaisesRegex(exception, regex):
+ ge(*inputs)
+ # optimized run
ge(*inputs)
- # optimized run
- ge(*inputs)
-
def checkBailouts(self, model, inputs, expected):
state = model.get_debug_state()