Record gradient in C
author     Akshay Modi <nareshmodi@google.com>
           Wed, 21 Feb 2018 22:42:11 +0000 (14:42 -0800)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Wed, 21 Feb 2018 22:48:12 +0000 (14:48 -0800)
PiperOrigin-RevId: 186522240

tensorflow/python/eager/backprop.py
tensorflow/python/eager/pywrap_tfe.h
tensorflow/python/eager/pywrap_tfe_src.cc
tensorflow/python/pywrap_tfe.i

diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index d8e13d7..5505661 100644
@@ -137,110 +137,6 @@ _gradient_functions_lock = threading.Lock()
 _tracing = False
 
 
-# TODO(apassos) replace this with a mechanism which can happen at the op
-# gradient function registration site, to be less error-prone
-# TODO(apassos) add ops other than those in nn_grad and math_grad
-_ops_which_dont_need_outputs = set([
-    "Identity",
-    "MatMul",
-    "Conv2DBackpropInput",
-    "Conv2DBackpropFilter",
-    "Conv3D",
-    "Conv3DBackpropInputV2",
-    "AvgPool3D",
-    "AvgPool3DGrad",
-    "MaxPool3D",
-    "MaxPool3DGrad",
-    "MaxPool3DGradGrad",
-    "BiasAdd",
-    "BiasAddV1",
-    "BiasAddGrad",
-    "Relu6",
-    "Softplus",
-    "SoftplusGrad",
-    "Softsign",
-    "ReluGrad",
-    "Conv2D",
-    "DepthwiseConv2dNative",
-    "Dilation2D",
-    "AvgPool",
-    "AvgPoolGrad",
-    "BatchNormWithGlobalNormalization",
-    "L2Loss",
-    "Sum",
-    "Prod",
-    "SegmentSum",
-    "SegmentMean",
-    "SparseSegmentSum",
-    "SparseSegmentMean",
-    "SparseSegmentSqrtN",
-    "SegmentMin",
-    "SegmentMax",
-    "UnsortedSegmentSum",
-    "UnsortedSegmentMax",
-    "Abs",
-    "Neg",
-    "ReciprocalGrad",
-    "Square",
-    "Expm1",
-    "Log",
-    "Log1p",
-    "TanhGrad",
-    "SigmoidGrad",
-    "Sign",
-    "Sin",
-    "Cos",
-    "Tan",
-    "Add",
-    "Sub",
-    "Mul",
-    "Div",
-    "RealDiv",
-    "Maximum",
-    "Minimum",
-    "SquaredDifference",
-    "Select",
-    "SparseMatMul",
-    "BatchMatMul",
-    "Complex",
-    "Real",
-    "Imag",
-    "Angle",
-    "Conj",
-    "Cast",
-    "Cross",
-    "Cumsum",
-    "Cumprod",
-    "ReadVariableOp",
-    "VarHandleOp",
-    "Shape",
-])
-
-_ops_which_dont_need_inputs = set([
-    "Identity",
-    "Softmax",
-    "LogSoftmax",
-    "BiasAdd",
-    "Relu",
-    "Elu",
-    "Selu",
-    "SparseSoftmaxCrossEntropyWithLogits",
-    "Neg",
-    "Inv",
-    "Reciprocal",
-    "Sqrt",
-    "Exp",
-    "Tanh",
-    "Sigmoid",
-    "Real",
-    "Imag",
-    "Conj",
-    "ReadVariableOp",
-    "VarHandleOp",
-    "Shape",
-])
-
-
 # TODO(agarwal): use an automatic mechanism for handling None arguments to
 # gradient functions.
 # Some gradient functions can accept None arguments for gradients. The following
@@ -259,57 +155,26 @@ _grad_fn_accepts_none_for_indices = {
 }
 
 
-def _record_gradient(op_name, inputs, attrs, results, name):
-  """Records gradients for a TensorFlow operation.
-
-  Args:
-    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
-      execute.
-    inputs: A flat list of Tensor object inputs to the operation.
-    attrs: A tuple with alternating string attr names and attr values for this
-      operation.
-    results: The results of the operation (as a flat list).
-    name: Customized name for the operation.
-
-  Returns:
-    A list of maybe-wrapped results. Either Tensors or TensorNodes.
-
-  Raises:
-    An exception on error.
-  """
-  if not tape.could_possibly_record():
-    return
-
-  if op_name in _ops_which_dont_need_outputs:
-    op_outputs = None
-  else:
-    # TODO(apassos) this line creates a weak circular reference where the
-    # backprop function keeps an output alive which in turn keeps the tape entry
-    # alive which keeps the backprop function alive. Figure out how to break
-    # this up without breaking second derivatives of ops like Exp whose
-    # gradients depend only on the outputs.
-    op_outputs = results
-
-  if op_name in _ops_which_dont_need_inputs:
-    op_inputs = None
-  else:
-    op_inputs = inputs
-
-  num_inputs = len(inputs)
+def _get_backward_fn(op_name, attrs, num_inputs, op_inputs, op_outputs):
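+  """Returns a backward function mapping output grads to input grads."""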
 
   def grad_fn(*orig_outputs):
-    """Generated gradient function."""
     result = _magic_gradient_function(op_name, attrs, num_inputs,
                                       op_inputs, op_outputs, orig_outputs)
     if _tracing:
-      print("Gradient for", (name if name else op_name), "inputs", op_inputs,
-            "output_grads", orig_outputs, "gradients", result)
+      print("Gradient for", op_name, "inputs", op_inputs, "output_grads",
+            orig_outputs, "gradients", result)
     return nest.flatten(result)
 
-  tape.record_operation(op_name, results, inputs, grad_fn)
-  if _tracing:
-    print("Computed op", (name if name else op_name), "inputs", inputs,
-          "outputs", results)
+  return grad_fn
+
+
+pywrap_tensorflow.TFE_Py_RegisterBackwardFunctionGetter(_get_backward_fn)
+
+
+def _record_gradient(op_name, inputs, attrs, results, name):
+  return pywrap_tensorflow.TFE_Py_RecordGradient(op_name, inputs, attrs,
+                                                 results, name)
 
 
 execute.record_gradient = _record_gradient
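
For orientation, a minimal sketch of the contract this change sets up between the C caller and the registered Python getter. toy_getter and the "Square" gradient rule below are hypothetical stand-ins; the real getter is _get_backward_fn above, and the real caller is RecordGradient in pywrap_tfe_src.cc:

    # Toy backward-function getter obeying the registered signature
    # (op_name, attrs, num_inputs, op_inputs, op_outputs) -> callable.
    def toy_getter(op_name, attrs, num_inputs, op_inputs, op_outputs):
      def grad_fn(*output_grads):
        (x,) = op_inputs           # "Square" reads its input, since
        (dy,) = output_grads       # d(x**2)/dx = 2x.
        return [2.0 * x * dy]
      return grad_fn

    # The C side builds the closure once when an op is recorded; the tape
    # later calls it with the output gradients during backprop.
    grad_fn = toy_getter("Square", (), 1, (3.0,), None)
    assert grad_fn(1.0) == [6.0]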
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 16b7d1a..f9692a8 100644
@@ -59,6 +59,15 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
 // This function is not thread-safe.
 PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e);
 
+// Registers e as the backward_function_getter.
+// The registered function creates a backward function (a function that can
+// return the gradients of an op's inputs given the gradients of its outputs).
+// The registered function will be passed the following arguments:
+//    op_name, attrs, num_inputs, op_inputs, op_outputs
+//
+// This function is not thread-safe.
+PyObject* TFE_Py_RegisterBackwardFunctionGetter(PyObject* e);
+
 // Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using
 // `exception` if not nullptr, else using the class registered via
 // TFE_Py_RegisterExceptionClass), and returns -1.
@@ -165,6 +174,11 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
 // directive.
 PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args);
 
+// Records the gradient for a given op.
+PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
+                                PyObject* attrs, PyObject* results,
+                                PyObject* name);
+
 // Returns the set of variables watched by the given tape.
 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);
 
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index cabbcc4..30e08c8 100644
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/gtl/compactptrset.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -575,6 +576,9 @@ PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
 // Python subclass of Exception that is created to signal fallback.
 PyObject* fallback_exception_class = nullptr;
 
+// Python function that returns a backward_function.
+PyObject* backward_function_getter = nullptr;
+
 tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
 tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0;
 
@@ -647,6 +651,23 @@ PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
   }
 }
 
+PyObject* TFE_Py_RegisterBackwardFunctionGetter(PyObject* e) {
+  if (backward_function_getter != nullptr) {
+    Py_DECREF(backward_function_getter);
+  }
+  if (!PyCallable_Check(e)) {
+    backward_function_getter = nullptr;
+    PyErr_SetString(PyExc_TypeError,
+                    "TFE_Py_RegisterBackwardFunctionGetter: "
+                    "Registered object should be a function.");
+    return nullptr;
+  } else {
+    Py_INCREF(e);
+    backward_function_getter = e;
+    Py_RETURN_NONE;
+  }
+}
+
 void RaiseFallbackException(const char* message) {
   if (fallback_exception_class != nullptr) {
     PyErr_SetObject(fallback_exception_class, Py_BuildValue("s", message));
@@ -1062,16 +1083,10 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
   return result;
 }
 
-void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
-                                   PyObject* input_tensors,
-                                   PyObject* backward_function) {
-  if (GetTapeSet()->empty() || *ThreadTapeIsStopped()) {
-    return;
-  }
-  std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
-  if (PyErr_Occurred()) {
-    return;
-  }
+namespace {
+void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
+                            const std::vector<tensorflow::int64>& input_ids,
+                            PyObject* backward_function) {
   std::vector<tensorflow::eager::TapeTensor> output_info;
   PyObject* seq = PySequence_Fast(output_tensors,
                                   "expected a sequence of integer tensor ids");
@@ -1110,6 +1125,19 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
         [backward_function]() { Py_DECREF(backward_function); });
   }
 }
+}  // namespace
+
+void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
+                                   PyObject* input_tensors,
+                                   PyObject* backward_function) {
+  if (GetTapeSet()->empty() || *ThreadTapeIsStopped()) {
+    return;
+  }
+  std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
+  if (PyErr_Occurred()) return;
+
+  TapeSetRecordOperation(op_type, output_tensors, input_ids, backward_function);
+}
 
 void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
   for (TFE_Py_Tape* tape : SafeTapeSet()) {
@@ -1430,6 +1458,165 @@ bool RaiseIfNotPyList(PyObject* list, const string& attr_name) {
   return true;
 }
 
+bool OpDoesntRequireOutput(const string& op_name) {
+  static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_outputs =
+      new tensorflow::gtl::FlatSet<string>({
+          "Identity",
+          "MatMul",
+          "Conv2DBackpropInput",
+          "Conv2DBackpropFilter",
+          "Conv3D",
+          "Conv3DBackpropInputV2",
+          "AvgPool3D",
+          "AvgPool3DGrad",
+          "MaxPool3D",
+          "MaxPool3DGrad",
+          "MaxPool3DGradGrad",
+          "BiasAdd",
+          "BiasAddV1",
+          "BiasAddGrad",
+          "Relu6",
+          "Softplus",
+          "SoftplusGrad",
+          "Softsign",
+          "ReluGrad",
+          "Conv2D",
+          "DepthwiseConv2dNative",
+          "Dilation2D",
+          "AvgPool",
+          "AvgPoolGrad",
+          "BatchNormWithGlobalNormalization",
+          "L2Loss",
+          "Sum",
+          "Prod",
+          "SegmentSum",
+          "SegmentMean",
+          "SparseSegmentSum",
+          "SparseSegmentMean",
+          "SparseSegmentSqrtN",
+          "SegmentMin",
+          "SegmentMax",
+          "UnsortedSegmentSum",
+          "UnsortedSegmentMax",
+          "Abs",
+          "Neg",
+          "ReciprocalGrad",
+          "Square",
+          "Expm1",
+          "Log",
+          "Log1p",
+          "TanhGrad",
+          "SigmoidGrad",
+          "Sign",
+          "Sin",
+          "Cos",
+          "Tan",
+          "Add",
+          "Sub",
+          "Mul",
+          "Div",
+          "RealDiv",
+          "Maximum",
+          "Minimum",
+          "SquaredDifference",
+          "Select",
+          "SparseMatMul",
+          "BatchMatMul",
+          "Complex",
+          "Real",
+          "Imag",
+          "Angle",
+          "Conj",
+          "Cast",
+          "Cross",
+          "Cumsum",
+          "Cumprod",
+          "ReadVariableOp",
+          "VarHandleOp",
+          "Shape",
+      });
+
+  return ops_that_dont_require_outputs->find(op_name) !=
+         ops_that_dont_require_outputs->end();
+}
+
+bool OpDoesntRequireInput(const string& op_name) {
+  static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_inputs =
+      new tensorflow::gtl::FlatSet<string>({
+          "Identity",
+          "Softmax",
+          "LogSoftmax",
+          "BiasAdd",
+          "Relu",
+          "Elu",
+          "Selu",
+          "SparseSoftmaxCrossEntropyWithLogits",
+          "Neg",
+          "Inv",
+          "Reciprocal",
+          "Sqrt",
+          "Exp",
+          "Tanh",
+          "Sigmoid",
+          "Real",
+          "Imag",
+          "Conj",
+          "ReadVariableOp",
+          "VarHandleOp",
+          "Shape",
+      });
+
+  return ops_that_dont_require_inputs->find(op_name) !=
+         ops_that_dont_require_inputs->end();
+}
+
+PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
+                         PyObject* results, PyObject* name) {
+  std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
+  if (PyErr_Occurred()) return nullptr;
+
+  bool should_record = false;
+  for (TFE_Py_Tape* tape : SafeTapeSet()) {
+    if (tape->tape->ShouldRecord(input_ids)) {
+      should_record = true;
+      break;
+    }
+  }
+
+  if (!should_record) Py_RETURN_NONE;
+
+  string c_op_name = TFE_GetPythonString(op_name);
+  PyObject* op_outputs;
+  if (OpDoesntRequireOutput(c_op_name)) {
+    op_outputs = Py_None;
+  } else {
+    op_outputs = results;
+  }
+
+  PyObject* op_inputs;
+  if (OpDoesntRequireInput(c_op_name)) {
+    op_inputs = Py_None;
+  } else {
+    op_inputs = inputs;
+  }
+
+  PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
+  PyObject* callback_args =
+      Py_BuildValue("OOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs);
+  Py_DECREF(num_inputs);  // Py_BuildValue's "O" format took its own reference.
+
+  PyObject* backward_function =
+      PyObject_CallObject(backward_function_getter, callback_args);
+  Py_DECREF(callback_args);
+  if (backward_function == nullptr) return nullptr;
+
+  TapeSetRecordOperation(op_name, results, input_ids, backward_function);
+
+  Py_DECREF(backward_function);
+
+  Py_RETURN_NONE;
+}
+
 bool RunCallbacks(bool run_gradient_callback, bool run_post_exec_callbacks,
                   const tensorflow::OpDef* op_def, PyObject* args,
                   const std::vector<PyObject*>& flattened_inputs,
@@ -1471,21 +1658,9 @@ bool RunCallbacks(bool run_gradient_callback, bool run_post_exec_callbacks,
   });
 
   if (run_gradient_callback) {
-    if (!PyCallable_Check(record_gradient_callback)) {
-      PyErr_SetString(PyExc_TypeError,
-                      Printf("expected a function for "
-                             "record_gradient_callback, got %s instead",
-                             record_gradient_callback->ob_type->tp_name)
-                          .c_str());
-      return false;
-    }
-
-    PyObject* callback_result =
-        PyObject_CallObject(record_gradient_callback, callback_args);
-    if (!callback_result) {
-      return false;
-    }
-    Py_DECREF(callback_result);
+    if (!RecordGradient(op_name, inputs, attrs, flattened_result, name)) {
+      return false;
+    }
   }
 
   if (run_post_exec_callbacks) {
@@ -1796,3 +1971,13 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
   Py_DECREF(flat_result);
   return result;
 }
+
+PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
+                                PyObject* attrs, PyObject* results,
+                                PyObject* name) {
+  if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) {
+    Py_RETURN_NONE;
+  }
+
+  return RecordGradient(op_name, inputs, attrs, results, name);
+}
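
The two FlatSets reproduce the Python sets deleted from backprop.py; their point is to let the tape avoid holding tensors a gradient will never read. A small plain-Python illustration (math only, no TensorFlow) of why "Exp" belongs in ops_that_dont_require_inputs while "MatMul" belongs in ops_that_dont_require_outputs:

    import math

    # d/dx exp(x) = exp(x): the backward pass needs only the forward
    # *output* y, so the input x need not be kept alive by the tape.
    x = 1.5
    y = math.exp(x)      # forward result
    dy = 2.0             # incoming output gradient
    dx = y * dy          # backward pass reads y alone
    assert abs(dx - math.exp(x) * dy) < 1e-12

MatMul is the mirror image: its gradients are dz @ B^T and A^T @ dz, which read only the operands, never the product.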
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 50f481d..7ab0db5 100644
@@ -29,9 +29,11 @@ limitations under the License.
 %rename("%s") TFE_OpNameGetAttrType;
 %rename("%s") TFE_Py_InitEagerTensor;
 %rename("%s") TFE_Py_RegisterExceptionClass;
+%rename("%s") TFE_Py_RegisterBackwardFunctionGetter;
 %rename("%s") TFE_Py_RegisterFallbackExceptionClass;
 %rename("%s") TFE_Py_Execute;
 %rename("%s") TFE_Py_FastPathExecute;
+%rename("%s") TFE_Py_RecordGradient;
 %rename("%s") TFE_Py_UID;
 %rename("%s") TFE_Py_TapeSetNew;
 %rename("%s") TFE_Py_TapeSetRemove;