From: Skye Wanderman-Milne Date: Tue, 23 Jan 2018 22:06:13 +0000 (-0800) Subject: C API: don't return source/sink control edges. X-Git-Tag: v1.6.0-rc0~105^2~1^2~41 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d50f68df6d1e4e41f4486ee71626454f6bd3ffe4;p=platform%2Fupstream%2Ftensorflow.git C API: don't return source/sink control edges. PiperOrigin-RevId: 182989157 --- diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 6fc75a98f1..92c9adbec7 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -1469,7 +1469,13 @@ int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, } int TF_OperationNumControlInputs(TF_Operation* oper) { - return oper->node.in_edges().size() - oper->node.num_inputs(); + int count = 0; + for (const auto* edge : oper->node.in_edges()) { + if (edge->IsControlEdge() && !edge->src()->IsSource()) { + ++count; + } + } + return count; } int TF_OperationGetControlInputs(TF_Operation* oper, @@ -1477,7 +1483,7 @@ int TF_OperationGetControlInputs(TF_Operation* oper, int max_control_inputs) { int count = 0; for (const auto* edge : oper->node.in_edges()) { - if (edge->IsControlEdge()) { + if (edge->IsControlEdge() && !edge->src()->IsSource()) { if (count < max_control_inputs) { control_inputs[count] = ToOperation(edge->src()); } @@ -1490,7 +1496,7 @@ int TF_OperationGetControlInputs(TF_Operation* oper, int TF_OperationNumControlOutputs(TF_Operation* oper) { int count = 0; for (const auto* edge : oper->node.out_edges()) { - if (edge->IsControlEdge()) { + if (edge->IsControlEdge() && !edge->dst()->IsSink()) { ++count; } } @@ -1502,7 +1508,7 @@ int TF_OperationGetControlOutputs(TF_Operation* oper, int max_control_outputs) { int count = 0; for (const auto* edge : oper->node.out_edges()) { - if (edge->IsControlEdge()) { + if (edge->IsControlEdge() && !edge->dst()->IsSink()) { if (count < max_control_outputs) { control_outputs[count] = ToOperation(edge->dst()); } diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index df697e16d3..01954eb235 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -575,7 +575,7 @@ TEST(CAPI, ImportGraphDef) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); - // Create a graph with two nodes: x and 3 + // Create a simple graph. Placeholder(graph, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr); @@ -586,7 +586,7 @@ TEST(CAPI, ImportGraphDef) { ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr); - // Export to a GraphDef + // Export to a GraphDef. TF_Buffer* graph_def = TF_NewBuffer(); TF_GraphToGraphDef(graph, graph_def, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); @@ -606,6 +606,31 @@ TEST(CAPI, ImportGraphDef) { ASSERT_TRUE(feed != nullptr); ASSERT_TRUE(neg != nullptr); + // Test basic structure of the imported graph. + EXPECT_EQ(0, TF_OperationNumInputs(scalar)); + EXPECT_EQ(0, TF_OperationNumInputs(feed)); + ASSERT_EQ(1, TF_OperationNumInputs(neg)); + TF_Output neg_input = TF_OperationInput({neg, 0}); + EXPECT_EQ(scalar, neg_input.oper); + EXPECT_EQ(0, neg_input.index); + + // Test that we can't see control edges involving the source and sink nodes. + TF_Operation* control_ops[100]; + EXPECT_EQ(0, TF_OperationNumControlInputs(scalar)); + EXPECT_EQ(0, TF_OperationGetControlInputs(scalar, control_ops, 100)); + EXPECT_EQ(0, TF_OperationNumControlOutputs(scalar)); + EXPECT_EQ(0, TF_OperationGetControlOutputs(scalar, control_ops, 100)); + + EXPECT_EQ(0, TF_OperationNumControlInputs(feed)); + EXPECT_EQ(0, TF_OperationGetControlInputs(feed, control_ops, 100)); + EXPECT_EQ(0, TF_OperationNumControlOutputs(feed)); + EXPECT_EQ(0, TF_OperationGetControlOutputs(feed, control_ops, 100)); + + EXPECT_EQ(0, TF_OperationNumControlInputs(neg)); + EXPECT_EQ(0, TF_OperationGetControlInputs(neg, control_ops, 100)); + EXPECT_EQ(0, TF_OperationNumControlOutputs(neg)); + EXPECT_EQ(0, TF_OperationGetControlOutputs(neg, control_ops, 100)); + // Import it again, with an input mapping, return outputs, and a return // operation, into the same graph. TF_DeleteImportGraphDefOptions(opts); @@ -629,7 +654,7 @@ TEST(CAPI, ImportGraphDef) { ASSERT_TRUE(neg2 != nullptr); // Check input mapping - TF_Output neg_input = TF_OperationInput({neg, 0}); + neg_input = TF_OperationInput({neg, 0}); EXPECT_EQ(scalar, neg_input.oper); EXPECT_EQ(0, neg_input.index);