From 4cfb393b087dc50c150054531186ccb71882e2d0 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Wed, 4 Apr 2018 15:42:14 -0700 Subject: [PATCH] Adding Operation._control_outputs PiperOrigin-RevId: 191659944 --- tensorflow/python/client/tf_session.i | 19 +++++++++++++++++++ tensorflow/python/client/tf_session_helper.cc | 9 +++++++++ tensorflow/python/client/tf_session_helper.h | 4 ++++ tensorflow/python/framework/ops.py | 24 ++++++++++++++++++++++++ tensorflow/python/framework/ops_test.py | 2 ++ 5 files changed, 58 insertions(+) diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 0c18d97..b82182d 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -157,6 +157,25 @@ tensorflow::ImportNumpy(); } } +// We use TF_OperationGetControlOutputs_wrapper instead of +// TF_OperationGetControlOutputs +%ignore TF_OperationGetControlOutputs; +%unignore TF_OperationGetControlOutputs_wrapper; +// See comment for "%noexception TF_SessionRun_wrapper;" +%noexception TF_OperationGetControlOutputs_wrapper; + +// Build a Python list of TF_Operation* and return it. +%typemap(out) std::vector tensorflow::TF_OperationGetControlOutputs_wrapper { + $result = PyList_New($1.size()); + if (!$result) { + SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list"); + } + + for (size_t i = 0; i < $1.size(); ++i) { + PyList_SET_ITEM($result, i, CreateWrappedTFOperation($1[i])); + } +} + %ignore TF_OperationOutputConsumers; %unignore TF_OperationOutputConsumers_wrapper; // See comment for "%noexception TF_SessionRun_wrapper;" diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index ca57abd..b48d758 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -550,6 +550,15 @@ std::vector TF_OperationGetControlInputs_wrapper( return control_inputs; } +std::vector TF_OperationGetControlOutputs_wrapper( + TF_Operation* oper) { + std::vector control_outputs( + TF_OperationNumControlOutputs(oper)); + TF_OperationGetControlOutputs(oper, control_outputs.data(), + control_outputs.size()); + return control_outputs; +} + std::vector TF_OperationOutputConsumers_wrapper( TF_Output oper_out) { int num_consumers = TF_OperationOutputNumConsumers(oper_out); diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 5416d41..d2b4abc 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -190,6 +190,10 @@ std::vector GetOperationInputs(TF_Operation* oper); std::vector TF_OperationGetControlInputs_wrapper( TF_Operation* oper); +// Retrieves the control outputs of this operation. +std::vector TF_OperationGetControlOutputs_wrapper( + TF_Operation* oper); + // Retrieves the op names of the consumers of `oper_out`. The returned strings // have the lifetime of the underlying TF_Graph. std::vector TF_OperationOutputConsumers_wrapper( diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 0215501..2d55f98 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -2114,6 +2114,30 @@ class Operation(object): return self._control_inputs_val @property + def _control_outputs(self): + """The `Operation` objects which have a control dependency on this op. + + Before any of the ops in self._control_outputs can execute tensorflow will + ensure self has finished executing. + + Returns: + A list of `Operation` objects. + + """ + if self._c_op: + control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op) + # pylint: disable=protected-access + return [ + self.graph._get_operation_by_name_unsafe( + c_api.TF_OperationName(c_op)) for c_op in control_c_ops + ] + # pylint: enable=protected-access + else: + # TODO(apassos) this should be less inefficient. + return [o for o in self._graph.get_operations() + if self in o.control_inputs] + + @property def _control_inputs(self): logging.warning("Operation._control_inputs is private, use " "Operation.control_inputs instead. " diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index aa51391..58bead9 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -473,6 +473,7 @@ class OperationTest(test_util.TensorFlowTestCase): self.assertEqual(z.control_inputs, [x, x]) z._add_control_inputs([x, y, y]) # pylint: disable=protected-access self.assertEqual(z.control_inputs, [x, x, x, y, y]) + self.assertEqual(x._control_outputs, [z]) def testAddControlInputC(self): # The C API dedups redundant control edges, pure Python does not @@ -487,6 +488,7 @@ class OperationTest(test_util.TensorFlowTestCase): self.assertEqual(z.control_inputs, [x]) z._add_control_inputs([x, y, y]) # pylint: disable=protected-access self.assertEqual(z.control_inputs, [x, y]) + self.assertEqual(x._control_outputs, [z]) def testRemoveAllControlInputs(self): a = constant_op.constant(1) -- 2.7.4