}
}
+// We use TF_OperationGetControlOutputs_wrapper instead of
+// TF_OperationGetControlOutputs
+%ignore TF_OperationGetControlOutputs;
+%unignore TF_OperationGetControlOutputs_wrapper;
+// See comment for "%noexception TF_SessionRun_wrapper;"
+%noexception TF_OperationGetControlOutputs_wrapper;
+
+// Build a Python list of TF_Operation* and return it.
+%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlOutputs_wrapper {
+ $result = PyList_New($1.size());
+ if (!$result) {
+ SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
+ }
+
+ for (size_t i = 0; i < $1.size(); ++i) {
+ PyList_SET_ITEM($result, i, CreateWrappedTFOperation($1[i]));
+ }
+}
+
%ignore TF_OperationOutputConsumers;
%unignore TF_OperationOutputConsumers_wrapper;
// See comment for "%noexception TF_SessionRun_wrapper;"
return control_inputs;
}
+std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper(
+ TF_Operation* oper) {
+ std::vector<TF_Operation*> control_outputs(
+ TF_OperationNumControlOutputs(oper));
+ TF_OperationGetControlOutputs(oper, control_outputs.data(),
+ control_outputs.size());
+ return control_outputs;
+}
+
std::vector<const char*> TF_OperationOutputConsumers_wrapper(
TF_Output oper_out) {
int num_consumers = TF_OperationOutputNumConsumers(oper_out);
std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
TF_Operation* oper);
+// Retrieves the control outputs of this operation.
+std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper(
+ TF_Operation* oper);
+
// Retrieves the op names of the consumers of `oper_out`. The returned strings
// have the lifetime of the underlying TF_Graph.
std::vector<const char*> TF_OperationOutputConsumers_wrapper(
return self._control_inputs_val
@property
+ def _control_outputs(self):
+ """The `Operation` objects which have a control dependency on this op.
+
+ Before any of the ops in self._control_outputs can execute tensorflow will
+ ensure self has finished executing.
+
+ Returns:
+ A list of `Operation` objects.
+
+ """
+ if self._c_op:
+ control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op)
+ # pylint: disable=protected-access
+ return [
+ self.graph._get_operation_by_name_unsafe(
+ c_api.TF_OperationName(c_op)) for c_op in control_c_ops
+ ]
+ # pylint: enable=protected-access
+ else:
+ # TODO(apassos) this should be less inefficient.
+ return [o for o in self._graph.get_operations()
+ if self in o.control_inputs]
+
+ @property
def _control_inputs(self):
logging.warning("Operation._control_inputs is private, use "
"Operation.control_inputs instead. "
self.assertEqual(z.control_inputs, [x, x])
z._add_control_inputs([x, y, y]) # pylint: disable=protected-access
self.assertEqual(z.control_inputs, [x, x, x, y, y])
+ self.assertEqual(x._control_outputs, [z])
def testAddControlInputC(self):
# The C API dedups redundant control edges, pure Python does not
self.assertEqual(z.control_inputs, [x])
z._add_control_inputs([x, y, y]) # pylint: disable=protected-access
self.assertEqual(z.control_inputs, [x, y])
+ self.assertEqual(x._control_outputs, [z])
def testRemoveAllControlInputs(self):
a = constant_op.constant(1)