Do not differentiate integers in the eager API.
author     Alexandre Passos <apassos@google.com>
           Tue, 8 May 2018 21:42:35 +0000 (14:42 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Tue, 8 May 2018 22:52:59 +0000 (15:52 -0700)
This is similar to the change made in
https://github.com/tensorflow/tensorflow/commit/f63750645826df65b05cad505546a86f0e347674
for backpropagation during graph construction via tf.gradients().

PiperOrigin-RevId: 195878952
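
For reference, the user-visible behavior after this change (a minimal sketch;
assumes eager execution is enabled, and mirrors the new testGradientInteger
test below):

```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

def f(x):
  return x + x

grad_fn = tfe.gradients_function(f)

# Float inputs are differentiated as before.
print(grad_fn(3.)[0].numpy())      # 2.0

# Integer inputs are no longer differentiated; the gradient is None.
print(grad_fn(tf.constant(1))[0])  # None
```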

tensorflow/c/eager/tape.h
tensorflow/contrib/eager/python/tfe_test.py
tensorflow/python/eager/backprop.py
tensorflow/python/eager/backprop_test.py
tensorflow/python/eager/pywrap_tensor.cc
tensorflow/python/eager/pywrap_tensor.h
tensorflow/python/eager/pywrap_tfe_src.cc

diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 8026076..e9ed339 100644
@@ -130,13 +130,15 @@ class GradientTape {
     }
   }
 
-  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids);
+  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
+                    gtl::ArraySlice<tensorflow::DataType> dtypes);
 
   void Watch(int64 tensor_id);
 
   void RecordOperation(const string& op_type,
                        gtl::ArraySlice<TapeTensor> output_tensors,
                        gtl::ArraySlice<int64> input_tensor_id,
+                       gtl::ArraySlice<tensorflow::DataType> input_dtypes,
                        BackwardFunction* backward_function,
                        const std::function<void()>& backward_function_deleter);
 
@@ -170,12 +172,30 @@ class GradientTape {
 
 // Template instantiations here
 
+inline bool IsDtypeTrainable(DataType dtype) {
+  switch (dtype) {
+    case DT_HALF:
+    case DT_BFLOAT16:
+    case DT_FLOAT:
+    case DT_DOUBLE:
+    case DT_COMPLEX64:
+    case DT_COMPLEX128:
+    case DT_RESOURCE:
+    case DT_VARIANT:
+      return true;
+    default:
+      return false;
+  }
+}
+
 template <typename Gradient, typename BackwardFunction>
 bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
-    gtl::ArraySlice<int64> tensor_ids) {
-  for (int64 i : tensor_ids) {
-    if (tensor_tape_.find(i) != tensor_tape_.end()) {
-      return true;
+    gtl::ArraySlice<int64> tensor_ids,
+    gtl::ArraySlice<tensorflow::DataType> dtypes) {
+  CHECK_EQ(tensor_ids.size(), dtypes.size());
+  for (int i = 0; i < tensor_ids.size(); ++i) {
+    if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
+      return IsDtypeTrainable(dtypes[i]);
     }
   }
   return false;
@@ -189,9 +209,11 @@ void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
 template <typename Gradient, typename BackwardFunction>
 void GradientTape<Gradient, BackwardFunction>::RecordOperation(
     const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
-    gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
+    gtl::ArraySlice<int64> input_tensor_id,
+    gtl::ArraySlice<tensorflow::DataType> input_dtypes,
+    BackwardFunction* backward_function,
     const std::function<void()>& backward_function_deleter) {
-  if (!ShouldRecord(input_tensor_id)) {
+  if (!ShouldRecord(input_tensor_id, input_dtypes)) {
     backward_function_deleter();
     return;
   }
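
The dtype allowlist in IsDtypeTrainable above is the core of the change. For
exposition only, here is a hypothetical Python rendering of the same check
(the names are illustrative; the real logic lives in the C++ tape):

```python
import tensorflow as tf

# Illustrative equivalent of the C++ IsDtypeTrainable: gradients are
# recorded only for floating-point, complex, resource, and variant
# dtypes; everything else (e.g. int32, bool, string) is skipped.
_TRAINABLE_DTYPES = frozenset([
    tf.float16, tf.bfloat16, tf.float32, tf.float64,
    tf.complex64, tf.complex128, tf.resource, tf.variant,
])

def is_dtype_trainable(dtype):
  return dtype in _TRAINABLE_DTYPES
```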
diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py
index e80ccbb..db50b33 100644
@@ -57,7 +57,7 @@ class TFETest(test_util.TensorFlowTestCase):
       return math_ops.multiply(x, x)
 
     grad = tfe.gradients_function(square)
-    self.assertEquals([6], [x.numpy() for x in grad(3)])
+    self.assertEquals([6], [x.numpy() for x in grad(3.)])
 
   def testGradOfGrad(self):
 
@@ -66,7 +66,7 @@ class TFETest(test_util.TensorFlowTestCase):
 
     grad = tfe.gradients_function(square)
     gradgrad = tfe.gradients_function(lambda x: grad(x)[0])
-    self.assertEquals([2], [x.numpy() for x in gradgrad(3)])
+    self.assertEquals([2], [x.numpy() for x in gradgrad(3.)])
 
   def testCustomGrad(self):
 
@@ -80,7 +80,7 @@ class TFETest(test_util.TensorFlowTestCase):
       return y, grad_fn
 
     grad = tfe.gradients_function(f)
-    self.assertEquals([12], [x.numpy() for x in grad(3)])
+    self.assertEquals([12], [x.numpy() for x in grad(3.)])
 
   def testGPU(self):
     if tfe.num_gpus() <= 0:
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index d04b004..967c128 100644
@@ -358,6 +358,8 @@ def gradients_function(f, params=None):
   assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
   ```
 
+  Note that only tensors with real or complex dtypes are differentiable.
+
   Args:
    f: function to be differentiated. If `f` returns a scalar, this scalar will
      be differentiated. If `f` returns a tensor or list of tensors, by default
@@ -700,6 +702,9 @@ class GradientTape(object):
   dz_dx = g.gradient(z, x)  # 108.0 (4*x^3 at x = 3)
   dy_dx = g.gradient(y, x)  # 6.0
   del g  # Drop the reference to the tape
+  ```
+
+  Note that only tensors with real or complex dtypes are differentiable.
   """
 
   def __init__(self, persistent=False):
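
To make the new docstring note concrete, a minimal sketch of GradientTape
with an integer tensor (assumes eager execution and a build where
GradientTape is exported as tf.GradientTape):

```python
import tensorflow as tf

tf.enable_eager_execution()

x = tf.constant(3)           # int32: not a trainable dtype
with tf.GradientTape() as g:
  g.watch(x)
  y = x * x
print(g.gradient(y, x))      # None: the op was never recorded
```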
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 8d9959f..be67448 100644
@@ -124,6 +124,14 @@ class BackpropTest(test.TestCase):
     grad_fn = backprop.gradients_function(f)
     self.assertAllEqual(2., grad_fn(1., dy=2.)[0])
 
+  def testGradientInteger(self):
+
+    def f(x):
+      return x + x
+
+    int_tensor = constant_op.constant(1)
+    self.assertEqual(backprop.gradients_function(f)(int_tensor)[0], None)
+
   def testErrors(self):
 
     @custom_gradient.custom_gradient
@@ -753,7 +761,7 @@ class BackpropTest(test.TestCase):
       return result, grad
 
     x = resource_variable_ops.ResourceVariable(
-        initial_value=3, name='X.' + self.id())
+        initial_value=3., name='X.' + self.id())
 
     def f():
       return my_square(x)
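
The initial_value fix above matters because ResourceVariable infers its
dtype from initial_value; with an integer initial value the variable would
now receive no gradients. A hypothetical sketch (variable name is
illustrative):

```python
from tensorflow.python.ops import resource_variable_ops

# 3  -> dtype inferred as int32: ops on the variable are not taped.
# 3. -> dtype inferred as float32: gradients are recorded as usual.
v = resource_variable_ops.ResourceVariable(initial_value=3., name='v')
```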
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index b5b4e39..b3aadd5 100644
@@ -650,6 +650,12 @@ tensorflow::int64 EagerTensor_id(const PyObject* tensor) {
   return reinterpret_cast<const EagerTensor*>(tensor)->id;
 }
 
+tensorflow::DataType EagerTensor_dtype(const PyObject* tensor) {
+  CHECK(EagerTensor_CheckExact(tensor));
+  return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
+      reinterpret_cast<const EagerTensor*>(tensor)->handle));
+}
+
 PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
   if (!PyType_Check(base_class)) {
     PyErr_SetString(
diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h
index 63ab1ed..88982b0 100644
@@ -21,6 +21,7 @@ limitations under the License.
 
 bool EagerTensor_CheckExact(const PyObject* o);
 tensorflow::int64 EagerTensor_id(const PyObject* tensor);
+tensorflow::DataType EagerTensor_dtype(const PyObject* tensor);
 
 namespace tensorflow {
 TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype);
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 4ecba1a..48a5b21 100644
@@ -843,6 +843,24 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) {
   return id;
 }
 
+static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
+  if (EagerTensor_CheckExact(tensor)) {
+    return EagerTensor_dtype(tensor);
+  }
+  PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype");
+  if (dtype_field == nullptr) {
+    return tensorflow::DT_INVALID;
+  }
+  PyObject* enum_field = PyObject_GetAttrString(dtype_field, "_type_enum");
+  Py_DECREF(dtype_field);
+  if (enum_field == nullptr) {
+    return tensorflow::DT_INVALID;
+  }
+  tensorflow::int64 id = MakeInt(enum_field);
+  Py_DECREF(enum_field);
+  return static_cast<tensorflow::DataType>(id);
+}
+
 class GradientTape
     : public tensorflow::eager::GradientTape<PyObject, PyObject> {
  public:
@@ -1053,15 +1071,18 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
   // TODO(apassos) consider not building a list and changing the API to check
   // each tensor individually.
   std::vector<tensorflow::int64> tensor_ids;
+  std::vector<tensorflow::DataType> dtypes;
   tensor_ids.reserve(len);
+  dtypes.reserve(len);
   for (int i = 0; i < len; ++i) {
     PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
     tensor_ids.push_back(FastTensorId(item));
+    dtypes.push_back(FastTensorDtype(item));
   }
   Py_DECREF(seq);
   auto tape_set = *tape_set_ptr;
   for (TFE_Py_Tape* tape : tape_set) {
-    if (tape->tape->ShouldRecord(tensor_ids)) {
+    if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
       Py_RETURN_TRUE;
     }
   }
@@ -1169,9 +1190,27 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
 }
 
 namespace {
-void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
-                            const std::vector<tensorflow::int64>& input_ids,
-                            PyObject* backward_function) {
+std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
+  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
+  if (seq == nullptr) {
+    return {};
+  }
+  int len = PySequence_Fast_GET_SIZE(seq);
+  std::vector<tensorflow::DataType> list;
+  list.reserve(len);
+  for (int i = 0; i < len; ++i) {
+    PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
+    list.push_back(FastTensorDtype(tensor));
+  }
+  Py_DECREF(seq);
+  return list;
+}
+
+void TapeSetRecordOperation(
+    PyObject* op_type, PyObject* output_tensors,
+    const std::vector<tensorflow::int64>& input_ids,
+    const std::vector<tensorflow::DataType>& input_dtypes,
+    PyObject* backward_function) {
   std::vector<tensorflow::eager::TapeTensor> output_info;
   PyObject* seq = PySequence_Fast(output_tensors,
                                   "expected a sequence of integer tensor ids");
@@ -1206,7 +1245,7 @@ void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
   for (TFE_Py_Tape* tape : SafeTapeSet()) {
     Py_INCREF(backward_function);
     tape->tape->RecordOperation(
-        op_type_str, output_info, input_ids, backward_function,
+        op_type_str, output_info, input_ids, input_dtypes, backward_function,
         [backward_function]() { Py_DECREF(backward_function); });
   }
 }
@@ -1221,7 +1260,11 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
   if (PyErr_Occurred()) return;
 
-  TapeSetRecordOperation(op_type, output_tensors, input_ids, backward_function);
+  std::vector<tensorflow::DataType> input_dtypes =
+      MakeTensorDtypeList(input_tensors);
+  if (PyErr_Occurred()) return;
+  TapeSetRecordOperation(op_type, output_tensors, input_ids, input_dtypes,
+                         backward_function);
 }
 
 void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
@@ -1710,10 +1753,12 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
                          PyObject* results, PyObject* name) {
   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
   if (PyErr_Occurred()) return nullptr;
+  std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
+  if (PyErr_Occurred()) return nullptr;
 
   bool should_record = false;
   for (TFE_Py_Tape* tape : SafeTapeSet()) {
-    if (tape->tape->ShouldRecord(input_ids)) {
+    if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
       should_record = true;
       break;
     }
@@ -1744,7 +1789,8 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
   Py_DECREF(callback_args);
   if (backward_function == nullptr) return nullptr;
 
-  TapeSetRecordOperation(op_name, results, input_ids, backward_function);
+  TapeSetRecordOperation(op_name, results, input_ids, input_dtypes,
+                         backward_function);
 
   Py_DECREF(backward_function);