Fix memory leak when going from the fast path to the slow path in eager
authorAkshay Modi <nareshmodi@google.com>
Tue, 22 May 2018 19:46:30 +0000 (12:46 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 22 May 2018 19:49:03 +0000 (12:49 -0700)
Fixes #19385

PiperOrigin-RevId: 197607384

tensorflow/python/eager/pywrap_tfe_src.cc

index 62deb41..9885b3d 100644 (file)
@@ -49,8 +49,7 @@ using AttrToInputsMap =
     tensorflow::gtl::FlatMap<string,
                              tensorflow::gtl::InlinedVector<InputInfo, 4>>;
 
-tensorflow::mutex all_attr_to_input_maps_lock(
-    tensorflow::LINKER_INITIALIZED);
+tensorflow::mutex all_attr_to_input_maps_lock(tensorflow::LINKER_INITIALIZED);
 tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
   static auto* all_attr_to_input_maps =
       new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
@@ -754,7 +753,7 @@ PyObject* TFE_Py_RegisterBackwardFunctionGetter(PyObject* e) {
 
 void RaiseFallbackException(const char* message) {
   if (fallback_exception_class != nullptr) {
-    PyErr_SetObject(fallback_exception_class, Py_BuildValue("s", message));
+    PyErr_SetString(fallback_exception_class, message);
     return;
   }
 
@@ -772,8 +771,9 @@ int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
   if (exception == nullptr) {
     tensorflow::mutex_lock l(exception_class_mutex);
     if (exception_class != nullptr) {
-      PyErr_SetObject(exception_class,
-                      Py_BuildValue("si", msg, TF_GetCode(status)));
+      tensorflow::Safe_PyObjectPtr val(
+          Py_BuildValue("si", msg, TF_GetCode(status)));
+      PyErr_SetObject(exception_class, val.get());
       return -1;
     } else {
       exception = PyExc_RuntimeError;
@@ -791,7 +791,8 @@ int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
   if (exception == nullptr) {
     tensorflow::mutex_lock l(exception_class_mutex);
     if (exception_class != nullptr) {
-      PyErr_SetObject(exception_class, Py_BuildValue("si", msg, status.code()));
+      tensorflow::Safe_PyObjectPtr val(Py_BuildValue("si", msg, status.code()));
+      PyErr_SetObject(exception_class, val.get());
       return -1;
     } else {
       exception = PyExc_RuntimeError;