py_func attaches full stack traces when an error is raised.
authorEugene Brevdo <ebrevdo@google.com>
Wed, 7 Mar 2018 22:53:49 +0000 (14:53 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 7 Mar 2018 22:57:37 +0000 (14:57 -0800)
This should help debugging errors that occur inside a py_func.

PiperOrigin-RevId: 188238495

tensorflow/python/kernel_tests/py_func_test.py
tensorflow/python/lib/core/py_util.cc
tensorflow/python/ops/script_ops.py

index 63203a0..3614280 100644 (file)
@@ -19,6 +19,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import re
+
 import numpy as np
 from six.moves import queue
 from six.moves import xrange  # pylint: disable=redefined-builtin
@@ -356,12 +358,22 @@ class PyFuncTest(test.TestCase):
 
   def _testExceptionHandling(self, py_exp, tf_exp, eager=False):
 
-    def raise_exception():
+    def inner_exception():
       raise py_exp("blah")  # pylint: disable=not-callable
 
+    def raise_exception():
+      inner_exception()
+
+    expected_regexp = r": blah.*"               # Error at the top
+    expected_regexp += r"in raise_exception.*"  # Stacktrace outer
+    expected_regexp += r"in inner_exception.*"  # Stacktrace inner
+    expected_regexp += r": blah"                # Stacktrace of raise
+    def expected_error_check(exception):
+      return re.search(expected_regexp, str(exception), re.DOTALL)
+
     if eager:
       if context.executing_eagerly():
-        with self.assertRaisesRegexp(tf_exp, "blah"):
+        with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
           f = script_ops.eager_py_func(raise_exception, [], [])
         return
       else:
@@ -370,7 +382,7 @@ class PyFuncTest(test.TestCase):
       f = script_ops.py_func(raise_exception, [], [])
 
     with self.test_session():
-      with self.assertRaisesRegexp(tf_exp, "blah"):
+      with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
         self.evaluate(f)
 
   def testExceptionHandling(self):
index 2635694..00cbf0c 100644 (file)
@@ -41,6 +41,55 @@ const char* ClassName(PyObject* py) {
 
 }  // end namespace
 
+// Returns a PyObject containing a string, or null
+void TryAppendTraceback(PyObject* ptype, PyObject* pvalue, PyObject* ptraceback,
+                        string* out) {
+  // The "traceback" module is assumed to be imported already by script_ops.py.
+  PyObject* tb_module = PyImport_AddModule("traceback");
+
+  if (!tb_module) {
+    return;
+  }
+
+  PyObject* format_exception =
+      PyObject_GetAttrString(tb_module, "format_exception");
+
+  if (!format_exception) {
+    return;
+  }
+
+  if (!PyCallable_Check(format_exception)) {
+    Py_DECREF(format_exception);
+    return;
+  }
+
+  PyObject* ret_val = PyObject_CallFunctionObjArgs(format_exception, ptype,
+                                                   pvalue, ptraceback, nullptr);
+  Py_DECREF(format_exception);
+
+  if (!ret_val) {
+    return;
+  }
+
+  if (!PyList_Check(ret_val)) {
+    Py_DECREF(ret_val);
+    return;
+  }
+
+  Py_ssize_t n = PyList_GET_SIZE(ret_val);
+  for (Py_ssize_t i = 0; i < n; ++i) {
+    PyObject* v = PyList_GET_ITEM(ret_val, i);
+#if PY_MAJOR_VERSION < 3
+    strings::StrAppend(out, PyString_AS_STRING(v), "\n");
+#else
+    strings::StrAppend(out, PyUnicode_AsUTF8(v), "\n");
+#endif
+  }
+
+  // Iterate through ret_val.
+  Py_DECREF(ret_val);
+}
+
 string PyExceptionFetch() {
   CHECK(PyErr_Occurred())
       << "Must only call PyExceptionFetch after an exception.";
@@ -52,14 +101,20 @@ string PyExceptionFetch() {
   string err = ClassName(ptype);
   if (pvalue) {
     PyObject* str = PyObject_Str(pvalue);
+
     if (str) {
 #if PY_MAJOR_VERSION < 3
-      strings::StrAppend(&err, ": ", PyString_AS_STRING(str));
+      strings::StrAppend(&err, ": ", PyString_AS_STRING(str), "\n");
 #else
-      strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str));
+      strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str), "\n");
 #endif
       Py_DECREF(str);
+    } else {
+      strings::StrAppend(&err, "(unknown error message)\n");
     }
+
+    TryAppendTraceback(ptype, pvalue, ptraceback, &err);
+
     Py_DECREF(pvalue);
   }
   Py_DECREF(ptype);
index 529eebe..fb59bbb 100644 (file)
@@ -25,6 +25,9 @@ from __future__ import print_function
 
 import threading
 
+# Used by py_util.cc to get tracebacks.
+import traceback  # pylint: disable=unused-import
+
 import numpy as np
 import six