From 32e39947c80ad410042a5aea266f197e9ecd289d Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Wed, 21 Feb 2018 14:42:11 -0800 Subject: [PATCH] Record gradient in C PiperOrigin-RevId: 186522240 --- tensorflow/python/eager/backprop.py | 160 ++------------------- tensorflow/python/eager/pywrap_tfe.h | 14 ++ tensorflow/python/eager/pywrap_tfe_src.cc | 232 ++++++++++++++++++++++++++---- tensorflow/python/pywrap_tfe.i | 2 + 4 files changed, 235 insertions(+), 173 deletions(-) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index d8e13d7..5505661 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -137,110 +137,6 @@ _gradient_functions_lock = threading.Lock() _tracing = False -# TODO(apassos) replace this with a mechanism which can happen at the op -# gradient function registration site, to be less error-prone -# TODO(apassos) add ops other than those in nn_grad and math_grad -_ops_which_dont_need_outputs = set([ - "Identity", - "MatMul", - "Conv2DBackpropInput", - "Conv2DBackpropFilter", - "Conv3D", - "Conv3DBackpropInputV2", - "AvgPool3D", - "AvgPool3DGrad", - "MaxPool3D", - "MaxPool3DGrad", - "MaxPool3DGradGrad", - "BiasAdd", - "BiasAddV1", - "BiasAddGrad", - "Relu6", - "Softplus", - "SoftplusGrad", - "Softsign", - "ReluGrad", - "Conv2D", - "DepthwiseConv2dNative", - "Dilation2D", - "AvgPool", - "AvgPoolGrad", - "BatchNormWithGlobalNormalization", - "L2Loss", - "Sum", - "Prod", - "SegmentSum", - "SegmentMean", - "SparseSegmentSum", - "SparseSegmentMean", - "SparseSegmentSqrtN", - "SegmentMin", - "SegmentMax", - "UnsortedSegmentSum", - "UnsortedSegmentMax", - "Abs", - "Neg", - "ReciprocalGrad", - "Square", - "Expm1", - "Log", - "Log1p", - "TanhGrad", - "SigmoidGrad", - "Sign", - "Sin", - "Cos", - "Tan", - "Add", - "Sub", - "Mul", - "Div", - "RealDiv", - "Maximum", - "Minimum", - "SquaredDifference", - "Select", - "SparseMatMul", - "BatchMatMul", - "Complex", - "Real", - "Imag", - "Angle", - "Conj", - "Cast", - "Cross", - "Cumsum", - "Cumprod", - "ReadVariableOp", - "VarHandleOp", - "Shape", -]) - -_ops_which_dont_need_inputs = set([ - "Identity", - "Softmax", - "LogSoftmax", - "BiasAdd", - "Relu", - "Elu", - "Selu", - "SparseSoftmaxCrossEntropyWithLogits", - "Neg", - "Inv", - "Reciprocal", - "Sqrt", - "Exp", - "Tanh", - "Sigmoid", - "Real", - "Imag", - "Conj", - "ReadVariableOp", - "VarHandleOp", - "Shape", -]) - - # TODO(agarwal): use an automatic mechanism for handling None arguments to # gradient functions. # Some gradient functions can accept None arguments for gradients. The following @@ -259,57 +155,25 @@ _grad_fn_accepts_none_for_indices = { } -def _record_gradient(op_name, inputs, attrs, results, name): - """Records gradients for a TensorFlow operation. - - Args: - op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to - execute. - inputs: A flat list of Tensor object inputs to the operation. - attrs: A tuple with alternating string attr names and attr values for this - operation. - results: The results of the operation (as a flat list). - name: Customized name for the operation. - - Returns: - A list of maybe-wrapped results. Either Tensors or TensorNodes. - - Raises: - An exception on error. - """ - if not tape.could_possibly_record(): - return - - if op_name in _ops_which_dont_need_outputs: - op_outputs = None - else: - # TODO(apassos) this line creates a weak circular reference where the - # backprop function keeps an output alive which in turn keeps the tape entry - # alive which keeps the backprop function alive. Figure out how to break - # this up without breaking second derivatives of ops like Exp whose - # gradients depend only on the outputs. - op_outputs = results - - if op_name in _ops_which_dont_need_inputs: - op_inputs = None - else: - op_inputs = inputs - - num_inputs = len(inputs) +def _get_backward_fn(op_name, attrs, num_inputs, op_inputs, op_outputs): def grad_fn(*orig_outputs): - """Generated gradient function.""" result = _magic_gradient_function(op_name, attrs, num_inputs, op_inputs, op_outputs, orig_outputs) if _tracing: - print("Gradient for", (name if name else op_name), "inputs", op_inputs, - "output_grads", orig_outputs, "gradients", result) + print("Gradient for", op_name, "inputs", op_inputs, "output_grads", + orig_outputs, "gradients", result) return nest.flatten(result) - tape.record_operation(op_name, results, inputs, grad_fn) - if _tracing: - print("Computed op", (name if name else op_name), "inputs", inputs, - "outputs", results) + return grad_fn + + +pywrap_tensorflow.TFE_Py_RegisterBackwardFunctionGetter(_get_backward_fn) + + +def _record_gradient(op_name, inputs, attrs, results, name): + return pywrap_tensorflow.TFE_Py_RecordGradient(op_name, inputs, attrs, + results, name) execute.record_gradient = _record_gradient diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index 16b7d1a..f9692a8 100644 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -59,6 +59,15 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e); // This function is not thread-safe. PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e); +// Registers e as the backward_function_getter. +// The registered function creates a backward function (a function that can +// return the gradient of the inputs an op given the gradient of it's outputs). +// The registered function will be passed the following arguments: +// op_name, attrs, num_inputs, op_inputs, op_outputs +// +// This function is not thread-safe. +PyObject* TFE_Py_RegisterBackwardFunctionGetter(PyObject* e); + // Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using // `exception` if not nullptr, else using the class registered via // TFE_Py_RegisterExceptionClass), and returns -1. @@ -165,6 +174,11 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, // directive. PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args); +// Record the gradient for a given op. +PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs, + PyObject* attrs, PyObject* results, + PyObject* name); + // Returns the set of variables watched by the given tape. PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index cabbcc4..30e08c8 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/compactptrset.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/mutex.h" @@ -575,6 +576,9 @@ PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr; // Python subclass of Exception that is created to signal fallback. PyObject* fallback_exception_class = nullptr; +// Python function that returns a backward_function. +PyObject* backward_function_getter = nullptr; + tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED); tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0; @@ -647,6 +651,23 @@ PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) { } } +PyObject* TFE_Py_RegisterBackwardFunctionGetter(PyObject* e) { + if (backward_function_getter != nullptr) { + Py_DECREF(backward_function_getter); + } + if (!PyCallable_Check(e)) { + backward_function_getter = nullptr; + PyErr_SetString(PyExc_TypeError, + "TFE_Py_RegisterBackwardFunctionGetter: " + "Registered object should be function."); + return nullptr; + } else { + Py_INCREF(e); + backward_function_getter = e; + Py_RETURN_NONE; + } +} + void RaiseFallbackException(const char* message) { if (fallback_exception_class != nullptr) { PyErr_SetObject(fallback_exception_class, Py_BuildValue("s", message)); @@ -1062,16 +1083,10 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { return result; } -void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, - PyObject* input_tensors, - PyObject* backward_function) { - if (GetTapeSet()->empty() || *ThreadTapeIsStopped()) { - return; - } - std::vector input_ids = MakeTensorIDList(input_tensors); - if (PyErr_Occurred()) { - return; - } +namespace { +void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, + const std::vector& input_ids, + PyObject* backward_function) { std::vector output_info; PyObject* seq = PySequence_Fast(output_tensors, "expected a sequence of integer tensor ids"); @@ -1110,6 +1125,19 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, [backward_function]() { Py_DECREF(backward_function); }); } } +} // namespace + +void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, + PyObject* input_tensors, + PyObject* backward_function) { + if (GetTapeSet()->empty() || *ThreadTapeIsStopped()) { + return; + } + std::vector input_ids = MakeTensorIDList(input_tensors); + if (PyErr_Occurred()) return; + + TapeSetRecordOperation(op_type, output_tensors, input_ids, backward_function); +} void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) { for (TFE_Py_Tape* tape : SafeTapeSet()) { @@ -1430,6 +1458,164 @@ bool RaiseIfNotPyList(PyObject* list, const string& attr_name) { return true; } +bool OpDoesntRequireOutput(const string& op_name) { + static tensorflow::gtl::FlatSet* ops_that_dont_require_outputs = + new tensorflow::gtl::FlatSet({ + "Identity", + "MatMul", + "Conv2DBackpropInput", + "Conv2DBackpropFilter", + "Conv3D", + "Conv3DBackpropInputV2", + "AvgPool3D", + "AvgPool3DGrad", + "MaxPool3D", + "MaxPool3DGrad", + "MaxPool3DGradGrad", + "BiasAdd", + "BiasAddV1", + "BiasAddGrad", + "Relu6", + "Softplus", + "SoftplusGrad", + "Softsign", + "ReluGrad", + "Conv2D", + "DepthwiseConv2dNative", + "Dilation2D", + "AvgPool", + "AvgPoolGrad", + "BatchNormWithGlobalNormalization", + "L2Loss", + "Sum", + "Prod", + "SegmentSum", + "SegmentMean", + "SparseSegmentSum", + "SparseSegmentMean", + "SparseSegmentSqrtN", + "SegmentMin", + "SegmentMax", + "UnsortedSegmentSum", + "UnsortedSegmentMax", + "Abs", + "Neg", + "ReciprocalGrad", + "Square", + "Expm1", + "Log", + "Log1p", + "TanhGrad", + "SigmoidGrad", + "Sign", + "Sin", + "Cos", + "Tan", + "Add", + "Sub", + "Mul", + "Div", + "RealDiv", + "Maximum", + "Minimum", + "SquaredDifference", + "Select", + "SparseMatMul", + "BatchMatMul", + "Complex", + "Real", + "Imag", + "Angle", + "Conj", + "Cast", + "Cross", + "Cumsum", + "Cumprod", + "ReadVariableOp", + "VarHandleOp", + "Shape", + }); + + return ops_that_dont_require_outputs->find(op_name) != + ops_that_dont_require_outputs->end(); +} + +bool OpDoesntRequireInput(const string& op_name) { + static tensorflow::gtl::FlatSet* ops_that_dont_require_inputs = + new tensorflow::gtl::FlatSet({ + "Identity", + "Softmax", + "LogSoftmax", + "BiasAdd", + "Relu", + "Elu", + "Selu", + "SparseSoftmaxCrossEntropyWithLogits", + "Neg", + "Inv", + "Reciprocal", + "Sqrt", + "Exp", + "Tanh", + "Sigmoid", + "Real", + "Imag", + "Conj", + "ReadVariableOp", + "VarHandleOp", + "Shape", + }); + + return ops_that_dont_require_inputs->find(op_name) != + ops_that_dont_require_inputs->end(); +} + +PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, + PyObject* results, PyObject* name) { + std::vector input_ids = MakeTensorIDList(inputs); + if (PyErr_Occurred()) return nullptr; + + bool should_record = false; + for (TFE_Py_Tape* tape : SafeTapeSet()) { + if (tape->tape->ShouldRecord(input_ids)) { + should_record = true; + break; + } + } + + if (!should_record) Py_RETURN_NONE; + + string c_op_name = TFE_GetPythonString(op_name); + PyObject* op_outputs; + if (OpDoesntRequireOutput(c_op_name)) { + op_outputs = Py_None; + } else { + op_outputs = results; + } + + PyObject* op_inputs; + if (OpDoesntRequireInput(c_op_name)) { + op_inputs = Py_None; + } else { + op_inputs = inputs; + } + + PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs)); + PyObject* callback_args = + Py_BuildValue("OOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs); + + PyObject* backward_function = + PyObject_CallObject(backward_function_getter, callback_args); + Py_DECREF(callback_args); + if (backward_function == nullptr) return nullptr; + + TapeSetRecordOperation(op_name, results, input_ids, backward_function); + + Py_DECREF(backward_function); + + Py_RETURN_NONE; +} + bool RunCallbacks(bool run_gradient_callback, bool run_post_exec_callbacks, const tensorflow::OpDef* op_def, PyObject* args, const std::vector& flattened_inputs, @@ -1471,21 +1657,7 @@ bool RunCallbacks(bool run_gradient_callback, bool run_post_exec_callbacks, }); if (run_gradient_callback) { - if (!PyCallable_Check(record_gradient_callback)) { - PyErr_SetString(PyExc_TypeError, - Printf("expected a function for " - "record_gradient_callback, got %s instead", - record_gradient_callback->ob_type->tp_name) - .c_str()); - return false; - } - - PyObject* callback_result = - PyObject_CallObject(record_gradient_callback, callback_args); - if (!callback_result) { - return false; - } - Py_DECREF(callback_result); + RecordGradient(op_name, inputs, attrs, flattened_result, name); } if (run_post_exec_callbacks) { @@ -1796,3 +1968,13 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { Py_DECREF(flat_result); return result; } + +PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs, + PyObject* attrs, PyObject* results, + PyObject* name) { + if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) { + Py_RETURN_NONE; + } + + return RecordGradient(op_name, inputs, attrs, results, name); +} diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 50f481d..7ab0db5 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -29,9 +29,11 @@ limitations under the License. %rename("%s") TFE_OpNameGetAttrType; %rename("%s") TFE_Py_InitEagerTensor; %rename("%s") TFE_Py_RegisterExceptionClass; +%rename("%s") TFE_Py_RegisterBackwardFunctionGetter; %rename("%s") TFE_Py_RegisterFallbackExceptionClass; %rename("%s") TFE_Py_Execute; %rename("%s") TFE_Py_FastPathExecute; +%rename("%s") TFE_Py_RecordGradient; %rename("%s") TFE_Py_UID; %rename("%s") TFE_Py_TapeSetNew; %rename("%s") TFE_Py_TapeSetRemove; -- 2.7.4