Adding Operation._control_outputs
authorAlexandre Passos <apassos@google.com>
Wed, 4 Apr 2018 22:42:14 +0000 (15:42 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 4 Apr 2018 22:44:55 +0000 (15:44 -0700)
PiperOrigin-RevId: 191659944

tensorflow/python/client/tf_session.i
tensorflow/python/client/tf_session_helper.cc
tensorflow/python/client/tf_session_helper.h
tensorflow/python/framework/ops.py
tensorflow/python/framework/ops_test.py

index 0c18d97..b82182d 100644 (file)
@@ -157,6 +157,25 @@ tensorflow::ImportNumpy();
   }
 }
 
+// We use TF_OperationGetControlOutputs_wrapper instead of
+// TF_OperationGetControlOutputs
+%ignore TF_OperationGetControlOutputs;
+%unignore TF_OperationGetControlOutputs_wrapper;
+// See comment for "%noexception TF_SessionRun_wrapper;"
+%noexception TF_OperationGetControlOutputs_wrapper;
+
+// Build a Python list of TF_Operation* and return it.
+%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlOutputs_wrapper {
+  $result = PyList_New($1.size());
+  if (!$result) {
+    SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
+  }
+
+  for (size_t i = 0; i < $1.size(); ++i) {
+    PyList_SET_ITEM($result, i, CreateWrappedTFOperation($1[i]));
+  }
+}
+
 %ignore TF_OperationOutputConsumers;
 %unignore TF_OperationOutputConsumers_wrapper;
 // See comment for "%noexception TF_SessionRun_wrapper;"
index ca57abd..b48d758 100644 (file)
@@ -550,6 +550,15 @@ std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
   return control_inputs;
 }
 
+std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper(
+    TF_Operation* oper) {
+  std::vector<TF_Operation*> control_outputs(
+      TF_OperationNumControlOutputs(oper));
+  TF_OperationGetControlOutputs(oper, control_outputs.data(),
+                                control_outputs.size());
+  return control_outputs;
+}
+
 std::vector<const char*> TF_OperationOutputConsumers_wrapper(
     TF_Output oper_out) {
   int num_consumers = TF_OperationOutputNumConsumers(oper_out);
index 5416d41..d2b4abc 100644 (file)
@@ -190,6 +190,10 @@ std::vector<TF_Output> GetOperationInputs(TF_Operation* oper);
 std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
     TF_Operation* oper);
 
+// Retrieves the control outputs of this operation.
+std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper(
+    TF_Operation* oper);
+
 // Retrieves the op names of the consumers of `oper_out`. The returned strings
 // have the lifetime of the underlying TF_Graph.
 std::vector<const char*> TF_OperationOutputConsumers_wrapper(
index 0215501..2d55f98 100644 (file)
@@ -2114,6 +2114,30 @@ class Operation(object):
       return self._control_inputs_val
 
   @property
+  def _control_outputs(self):
+    """The `Operation` objects which have a control dependency on this op.
+
+    Before any of the ops in self._control_outputs can execute tensorflow will
+    ensure self has finished executing.
+
+    Returns:
+      A list of `Operation` objects.
+
+    """
+    if self._c_op:
+      control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op)
+      # pylint: disable=protected-access
+      return [
+          self.graph._get_operation_by_name_unsafe(
+              c_api.TF_OperationName(c_op)) for c_op in control_c_ops
+      ]
+      # pylint: enable=protected-access
+    else:
+      # TODO(apassos) this should be less inefficient.
+      return [o for o in self._graph.get_operations()
+              if self in o.control_inputs]
+
+  @property
   def _control_inputs(self):
     logging.warning("Operation._control_inputs is private, use "
                     "Operation.control_inputs instead. "
index aa51391..58bead9 100644 (file)
@@ -473,6 +473,7 @@ class OperationTest(test_util.TensorFlowTestCase):
     self.assertEqual(z.control_inputs, [x, x])
     z._add_control_inputs([x, y, y])  # pylint: disable=protected-access
     self.assertEqual(z.control_inputs, [x, x, x, y, y])
+    self.assertEqual(x._control_outputs, [z])
 
   def testAddControlInputC(self):
     # The C API dedups redundant control edges, pure Python does not
@@ -487,6 +488,7 @@ class OperationTest(test_util.TensorFlowTestCase):
     self.assertEqual(z.control_inputs, [x])
     z._add_control_inputs([x, y, y])  # pylint: disable=protected-access
     self.assertEqual(z.control_inputs, [x, y])
+    self.assertEqual(x._control_outputs, [z])
 
   def testRemoveAllControlInputs(self):
     a = constant_op.constant(1)