from __future__ import division
from __future__ import print_function
+import re
+
import numpy as np
from six.moves import queue
from six.moves import xrange # pylint: disable=redefined-builtin
def _testExceptionHandling(self, py_exp, tf_exp, eager=False):
- def raise_exception():
+ def inner_exception():
raise py_exp("blah") # pylint: disable=not-callable
+ def raise_exception():
+ inner_exception()
+
+ expected_regexp = r": blah.*" # Error at the top
+ expected_regexp += r"in raise_exception.*" # Stacktrace outer
+ expected_regexp += r"in inner_exception.*" # Stacktrace inner
+ expected_regexp += r": blah" # Stacktrace of raise
+ def expected_error_check(exception):
+ return re.search(expected_regexp, str(exception), re.DOTALL)
+
if eager:
if context.executing_eagerly():
- with self.assertRaisesRegexp(tf_exp, "blah"):
+ with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
f = script_ops.eager_py_func(raise_exception, [], [])
return
else:
f = script_ops.py_func(raise_exception, [], [])
with self.test_session():
- with self.assertRaisesRegexp(tf_exp, "blah"):
+ with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
self.evaluate(f)
def testExceptionHandling(self):
} // end namespace
+// Returns a PyObject containing a string, or null
+void TryAppendTraceback(PyObject* ptype, PyObject* pvalue, PyObject* ptraceback,
+ string* out) {
+ // The "traceback" module is assumed to be imported already by script_ops.py.
+ PyObject* tb_module = PyImport_AddModule("traceback");
+
+ if (!tb_module) {
+ return;
+ }
+
+ PyObject* format_exception =
+ PyObject_GetAttrString(tb_module, "format_exception");
+
+ if (!format_exception) {
+ return;
+ }
+
+ if (!PyCallable_Check(format_exception)) {
+ Py_DECREF(format_exception);
+ return;
+ }
+
+ PyObject* ret_val = PyObject_CallFunctionObjArgs(format_exception, ptype,
+ pvalue, ptraceback, nullptr);
+ Py_DECREF(format_exception);
+
+ if (!ret_val) {
+ return;
+ }
+
+ if (!PyList_Check(ret_val)) {
+ Py_DECREF(ret_val);
+ return;
+ }
+
+ Py_ssize_t n = PyList_GET_SIZE(ret_val);
+ for (Py_ssize_t i = 0; i < n; ++i) {
+ PyObject* v = PyList_GET_ITEM(ret_val, i);
+#if PY_MAJOR_VERSION < 3
+ strings::StrAppend(out, PyString_AS_STRING(v), "\n");
+#else
+ strings::StrAppend(out, PyUnicode_AsUTF8(v), "\n");
+#endif
+ }
+
+ // Iterate through ret_val.
+ Py_DECREF(ret_val);
+}
+
string PyExceptionFetch() {
CHECK(PyErr_Occurred())
<< "Must only call PyExceptionFetch after an exception.";
string err = ClassName(ptype);
if (pvalue) {
PyObject* str = PyObject_Str(pvalue);
+
if (str) {
#if PY_MAJOR_VERSION < 3
- strings::StrAppend(&err, ": ", PyString_AS_STRING(str));
+ strings::StrAppend(&err, ": ", PyString_AS_STRING(str), "\n");
#else
- strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str));
+ strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str), "\n");
#endif
Py_DECREF(str);
+ } else {
+ strings::StrAppend(&err, "(unknown error message)\n");
}
+
+ TryAppendTraceback(ptype, pvalue, ptraceback, &err);
+
Py_DECREF(pvalue);
}
Py_DECREF(ptype);