C API: don't return source/sink control edges.
authorSkye Wanderman-Milne <skyewm@google.com>
Tue, 23 Jan 2018 22:06:13 +0000 (14:06 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 23 Jan 2018 22:09:50 +0000 (14:09 -0800)
PiperOrigin-RevId: 182989157

tensorflow/c/c_api.cc
tensorflow/c/c_api_test.cc

index 6fc75a98f1e05c3971cb4546bd16f015c25b6709..92c9adbec7dd68fef418b953369948843fab2c75 100644 (file)
@@ -1469,7 +1469,13 @@ int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
 }
 
 int TF_OperationNumControlInputs(TF_Operation* oper) {
-  return oper->node.in_edges().size() - oper->node.num_inputs();
+  int count = 0;
+  for (const auto* edge : oper->node.in_edges()) {
+    if (edge->IsControlEdge() && !edge->src()->IsSource()) {
+      ++count;
+    }
+  }
+  return count;
 }
 
 int TF_OperationGetControlInputs(TF_Operation* oper,
@@ -1477,7 +1483,7 @@ int TF_OperationGetControlInputs(TF_Operation* oper,
                                  int max_control_inputs) {
   int count = 0;
   for (const auto* edge : oper->node.in_edges()) {
-    if (edge->IsControlEdge()) {
+    if (edge->IsControlEdge() && !edge->src()->IsSource()) {
       if (count < max_control_inputs) {
         control_inputs[count] = ToOperation(edge->src());
       }
@@ -1490,7 +1496,7 @@ int TF_OperationGetControlInputs(TF_Operation* oper,
 int TF_OperationNumControlOutputs(TF_Operation* oper) {
   int count = 0;
   for (const auto* edge : oper->node.out_edges()) {
-    if (edge->IsControlEdge()) {
+    if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
       ++count;
     }
   }
@@ -1502,7 +1508,7 @@ int TF_OperationGetControlOutputs(TF_Operation* oper,
                                   int max_control_outputs) {
   int count = 0;
   for (const auto* edge : oper->node.out_edges()) {
-    if (edge->IsControlEdge()) {
+    if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
       if (count < max_control_outputs) {
         control_outputs[count] = ToOperation(edge->dst());
       }
index df697e16d3d3fcaac66f967c0d3938450f0b0be6..01954eb235f1a93d943c2ec7ea4c5ca44785d402 100644 (file)
@@ -575,7 +575,7 @@ TEST(CAPI, ImportGraphDef) {
   TF_Status* s = TF_NewStatus();
   TF_Graph* graph = TF_NewGraph();
 
-  // Create a graph with two nodes: x and 3
+  // Create a simple graph.
   Placeholder(graph, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
   ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
@@ -586,7 +586,7 @@ TEST(CAPI, ImportGraphDef) {
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
   ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);
 
-  // Export to a GraphDef
+  // Export to a GraphDef.
   TF_Buffer* graph_def = TF_NewBuffer();
   TF_GraphToGraphDef(graph, graph_def, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
@@ -606,6 +606,31 @@ TEST(CAPI, ImportGraphDef) {
   ASSERT_TRUE(feed != nullptr);
   ASSERT_TRUE(neg != nullptr);
 
+  // Test basic structure of the imported graph.
+  EXPECT_EQ(0, TF_OperationNumInputs(scalar));
+  EXPECT_EQ(0, TF_OperationNumInputs(feed));
+  ASSERT_EQ(1, TF_OperationNumInputs(neg));
+  TF_Output neg_input = TF_OperationInput({neg, 0});
+  EXPECT_EQ(scalar, neg_input.oper);
+  EXPECT_EQ(0, neg_input.index);
+
+  // Test that we can't see control edges involving the source and sink nodes.
+  TF_Operation* control_ops[100];
+  EXPECT_EQ(0, TF_OperationNumControlInputs(scalar));
+  EXPECT_EQ(0, TF_OperationGetControlInputs(scalar, control_ops, 100));
+  EXPECT_EQ(0, TF_OperationNumControlOutputs(scalar));
+  EXPECT_EQ(0, TF_OperationGetControlOutputs(scalar, control_ops, 100));
+
+  EXPECT_EQ(0, TF_OperationNumControlInputs(feed));
+  EXPECT_EQ(0, TF_OperationGetControlInputs(feed, control_ops, 100));
+  EXPECT_EQ(0, TF_OperationNumControlOutputs(feed));
+  EXPECT_EQ(0, TF_OperationGetControlOutputs(feed, control_ops, 100));
+
+  EXPECT_EQ(0, TF_OperationNumControlInputs(neg));
+  EXPECT_EQ(0, TF_OperationGetControlInputs(neg, control_ops, 100));
+  EXPECT_EQ(0, TF_OperationNumControlOutputs(neg));
+  EXPECT_EQ(0, TF_OperationGetControlOutputs(neg, control_ops, 100));
+
   // Import it again, with an input mapping, return outputs, and a return
   // operation, into the same graph.
   TF_DeleteImportGraphDefOptions(opts);
@@ -629,7 +654,7 @@ TEST(CAPI, ImportGraphDef) {
   ASSERT_TRUE(neg2 != nullptr);
 
   // Check input mapping
-  TF_Output neg_input = TF_OperationInput({neg, 0});
+  neg_input = TF_OperationInput({neg, 0});
   EXPECT_EQ(scalar, neg_input.oper);
   EXPECT_EQ(0, neg_input.index);