Eager/C: Add a TF_Status argument to TFE_TensorHandleNumDims, TFE_TensorHandleDim, and TFE_TensorHandleDeviceName.
author A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 23 Feb 2018 23:06:46 +0000 (15:06 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 23 Feb 2018 23:10:23 +0000 (15:10 -0800)
PiperOrigin-RevId: 186829318

tensorflow/c/eager/c_api.cc
tensorflow/c/eager/c_api.h
tensorflow/c/eager/c_api_test.cc
tensorflow/python/eager/pywrap_tensor.cc
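
For orientation before the diffs: TFE_TensorHandleNumDims, TFE_TensorHandleDim, and TFE_TensorHandleDeviceName now take a TF_Status* out-parameter. The implementations below always report TF_OK for now, but callers should still check the status, since the signature change exists so future versions (e.g. the distributed case flagged in the TODO) can report real errors. A minimal caller-side sketch of the new convention, assuming a valid TFE_TensorHandle* h; PrintRank is illustrative, not part of the API:

    #include <cstdio>

    #include "tensorflow/c/c_api.h"        // TF_Status, TF_GetCode, TF_Message
    #include "tensorflow/c/eager/c_api.h"  // TFE_TensorHandle APIs

    void PrintRank(TFE_TensorHandle* h) {
      TF_Status* status = TF_NewStatus();
      int num_dims = TFE_TensorHandleNumDims(h, status);
      if (TF_GetCode(status) != TF_OK) {
        // On error the returned value is not meaningful.
        std::fprintf(stderr, "NumDims failed: %s\n", TF_Message(status));
      } else {
        std::printf("rank = %d\n", num_dims);
      }
      TF_DeleteStatus(status);
    }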

diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index cc318c3..f615e3f 100644
@@ -154,16 +154,22 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
   return static_cast<TF_DataType>(h->t.dtype());
 }
 
-int TFE_TensorHandleNumDims(TFE_TensorHandle* h) { return h->t.dims(); }
+int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
+  status->status = tensorflow::Status::OK();
+  return h->t.dims();
+}
 
-int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index) {
+int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
+                            TF_Status* status) {
+  status->status = tensorflow::Status::OK();
   return h->t.dim_size(dim_index);
 }
 
-const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) {
+const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
   // TODO(apassos) this will be potentially incorrect in the distributed case as
   // our local device will have a name which depends on the ClusterSpec and
   // hence will require the context to resolve.
+  status->status = tensorflow::Status::OK();
   return (h->d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                            : h->d->name().c_str();
 }
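
Fetching a complete shape under the new signatures means a status check after each call; this mirrors what EagerTensor_shape_tuple does in the Python binding further down. A sketch of a hypothetical helper (ShapeOrEmpty is illustrative only, not part of this patch):

    #include <cstdint>
    #include <vector>

    #include "tensorflow/c/eager/c_api.h"

    // Hypothetical helper: collects every dimension size, returning an empty
    // vector and leaving `status` non-OK if any call fails.
    std::vector<int64_t> ShapeOrEmpty(TFE_TensorHandle* h, TF_Status* status) {
      std::vector<int64_t> dims;
      int n = TFE_TensorHandleNumDims(h, status);
      if (TF_GetCode(status) != TF_OK) return {};
      dims.reserve(n);
      for (int i = 0; i < n; ++i) {
        int64_t d = TFE_TensorHandleDim(h, i, status);
        if (TF_GetCode(status) != TF_OK) return {};
        dims.push_back(d);
      }
      return dims;
    }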
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 7a321b5..90cfb75 100644
@@ -119,11 +119,13 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
                                                             TF_Status* status);
 TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
 TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
-TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h);
+TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h,
+                                                  TF_Status* status);
 TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
-                                                  int dim_index);
+                                                  int dim_index,
+                                                  TF_Status* status);
 TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
-    TFE_TensorHandle* h);
+    TFE_TensorHandle* h, TF_Status* status);
 TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
                                                          TF_Status* status);
 
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 4a3ecbc..00fb7e6 100644
@@ -932,7 +932,8 @@ TEST(CAPI, Variables) {
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   ASSERT_EQ(1, num_retvals);
   EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle));
-  EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle));
+  EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle, status));
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   float value = 0.0f;
   TF_Tensor* t = TFE_TensorHandleResolve(value_handle, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -974,7 +975,8 @@ void BM_ReadVariable(int iters) {
     CHECK_EQ(1, num_retvals);
     CHECK(h);
     CHECK_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
-    CHECK_EQ(0, TFE_TensorHandleNumDims(h));
+    CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
     h = nullptr;
   }
   tensorflow::testing::StopTiming();
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index 6fa0765..3ec2109 100644
@@ -185,6 +185,12 @@ typedef struct EagerTensor {
 
   // This stores `_keras_mask` object and is set by Tensorflow layers.
   PyObject* keras_mask;
+
+  // We store a status object here as an optimization to avoid allocating new
+  // Status objects in the different functions that operate on EagerTensor and
+  // need to use a TF_Status object. However, note that accesses to `status`
+  // are not thread-safe.
+  TF_Status* status;
 } EagerTensor;
 
 // tp_init for EagerTensor.
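
The cached `status` member above saves a TF_NewStatus/TF_DeleteStatus pair on every getter call, and the invariant the hunks below maintain is: any method that trips the status resets it to TF_OK before returning, so the next use starts clean. A condensed sketch of that pattern, with a direct PyErr_SetString standing in for the internal MaybeRaiseExceptionFromTFStatus helper; EagerTensor_rank_sketch is illustrative only, and Python 3 is assumed:

    // Illustrative only: the reuse-and-reset pattern for self->status.
    static PyObject* EagerTensor_rank_sketch(EagerTensor* self) {
      int n = TFE_TensorHandleNumDims(self->handle, self->status);
      if (TF_GetCode(self->status) != TF_OK) {
        PyErr_SetString(PyExc_ValueError, TF_Message(self->status));
        // Reset before returning so later calls see a clean status.
        TF_SetStatus(self->status, TF_OK, "");
        return nullptr;
      }
      return PyLong_FromLong(n);
    }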
@@ -195,6 +201,7 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
   self->handle_data = Py_None;
   Py_INCREF(Py_None);
   self->keras_mask = Py_None;
+  self->status = TF_NewStatus();
   PyObject* value;
   PyObject* context = nullptr;
   PyObject* device = nullptr;
@@ -269,17 +276,17 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
   }
   TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
   if (desired_dtype >= 0 && desired_dtype != handle_dtype) {
-    auto out_status = tensorflow::make_safe(TF_NewStatus());
     handle = tensorflow::make_safe(
         EagerCast(GetContext(context), handle.get(), handle_dtype,
-                  static_cast<TF_DataType>(desired_dtype), out_status.get()));
-    if (TF_GetCode(out_status.get()) != TF_OK) {
-      PyErr_SetString(
-          PyExc_ValueError,
-          tensorflow::strings::StrCat("Error while casting from DataType ",
-                                      handle_dtype, " to ", desired_dtype, ". ",
-                                      TF_Message(out_status.get()))
-              .c_str());
+                  static_cast<TF_DataType>(desired_dtype), self->status));
+    if (TF_GetCode(self->status) != TF_OK) {
+      PyErr_SetString(PyExc_ValueError,
+                      tensorflow::strings::StrCat(
+                          "Error while casting from DataType ", handle_dtype,
+                          " to ", desired_dtype, ". ", TF_Message(self->status))
+                          .c_str());
+      // Clean up self->status before returning.
+      TF_SetStatus(self->status, TF_OK, "");
       return -1;
     }
     handle_dtype = TFE_TensorHandleDataType(handle.get());
@@ -323,6 +330,7 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
 
 // tp_dealloc for EagerTensor.
 void EagerTensor_dealloc(EagerTensor* self) {
+  TF_DeleteStatus(self->status);
   Py_DECREF(self->handle_data);
   Py_DECREF(self->keras_mask);
   TFE_DeleteTensorHandle(self->handle);
@@ -348,12 +356,21 @@ static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
 // Getter for `_shape_tuple`.
 static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
   auto handle = self->handle;
-  int n = TFE_TensorHandleNumDims(handle);
+  int n = TFE_TensorHandleNumDims(handle, self->status);
+  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+    // Clean up self->status before returning.
+    TF_SetStatus(self->status, TF_OK, "");
+    return nullptr;
+  }
   PyObject* shape = PyTuple_New(n);
   if (PyErr_Occurred()) return nullptr;
   for (int i = 0; i < n; ++i) {
-    PyObject* dim = PyLong_FromLongLong(TFE_TensorHandleDim(handle, i));
-    if (dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
+    PyObject* dim =
+        PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, self->status));
+    if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError) ||
+        dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
+      // Clean up self->status before returning.
+      TF_SetStatus(self->status, TF_OK, "");
       Py_DECREF(shape);
       if (dim != nullptr) Py_DECREF(dim);
       PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
@@ -365,10 +382,16 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
 
 // Getter for `_rank`.
 static PyObject* EagerTensor_rank(EagerTensor* self) {
+  int num_dims = TFE_TensorHandleNumDims(self->handle, self->status);
+  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+    // Clean up self->status before returning.
+    TF_SetStatus(self->status, TF_OK, "");
+    return nullptr;
+  }
 #if PY_MAJOR_VERSION < 3
-  return PyInt_FromLong(TFE_TensorHandleNumDims(self->handle));
+  return PyInt_FromLong(num_dims);
 #else
-  return PyLong_FromLong(TFE_TensorHandleNumDims(self->handle));
+  return PyLong_FromLong(num_dims);
 #endif
 }
 
@@ -437,10 +460,16 @@ static PyObject* EagerTensor_numpy(EagerTensor* self) {
 
 // Getter `device`.
 static PyObject* EagerTensor_device(EagerTensor* self) {
+  const char* device = TFE_TensorHandleDeviceName(self->handle, self->status);
+  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+    // Clean up self->status before returning.
+    TF_SetStatus(self->status, TF_OK, "");
+    return nullptr;
+  }
 #if PY_MAJOR_VERSION >= 3
-  return PyUnicode_FromString(TFE_TensorHandleDeviceName(self->handle));
+  return PyUnicode_FromString(device);
 #else
-  return PyBytes_FromString(TFE_TensorHandleDeviceName(self->handle));
+  return PyBytes_FromString(device);
 #endif
 }
 
@@ -576,6 +605,7 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
     Py_INCREF(Py_None);
     t->keras_mask = Py_None;
     t->handle = handle;
+    t->status = TF_NewStatus();
   }
   return reinterpret_cast<PyObject*>(t);
 }
@@ -673,6 +703,7 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim) {
   auto tensor = tensorflow::make_safe(TF_AllocateTensor(
       TF_INT32, &num_tensors_int, /*num_dims=*/1, /*len=*/4 * num_tensors_int));
   int32_t* data = reinterpret_cast<int32_t*>(TF_TensorData(tensor.get()));
+  auto status = tensorflow::make_safe(TF_NewStatus());
   for (Py_ssize_t i = 0; i < num_tensors; ++i) {
     PyObject* tensor_obj = PyList_GET_ITEM(tensor_list, i);
     if (!EagerTensor_CheckExact(tensor_obj)) {
@@ -687,21 +718,27 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim) {
 
     EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
     TFE_TensorHandle* handle = t->handle;
-    if (slice_dim >= TFE_TensorHandleNumDims(handle)) {
-      PyErr_SetString(PyExc_IndexError,
-                      tensorflow::strings::StrCat(
-                          "Slice dimension (", slice_dim,
-                          ") must be smaller than rank of all "
-                          "tensors, but tensor at index ",
-                          i, " has rank ", TFE_TensorHandleNumDims(handle))
-                          .c_str());
+    int num_dims = TFE_TensorHandleNumDims(handle, status.get());
+    if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
+      return nullptr;
+    }
+    if (slice_dim >= num_dims) {
+      PyErr_SetString(
+          PyExc_IndexError,
+          tensorflow::strings::StrCat("Slice dimension (", slice_dim,
+                                      ") must be smaller than rank of all "
+                                      "tensors, but tensor at index ",
+                                      i, " has rank ", num_dims)
+              .c_str());
+      return nullptr;
+    }
+    int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get());
+    if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
       return nullptr;
     }
-    int64_t dim = TFE_TensorHandleDim(handle, slice_dim);
     data[i] = dim;
   }
 
-  auto status = tensorflow::make_safe(TF_NewStatus());
   TFE_TensorHandle* handle = TFE_NewTensorHandle(tensor.get(), status.get());
   if (TF_GetCode(status.get()) != TF_OK) {
     PyErr_SetString(