}
int TF_OperationNumControlInputs(TF_Operation* oper) {
- return oper->node.in_edges().size() - oper->node.num_inputs();
+ int count = 0;
+ for (const auto* edge : oper->node.in_edges()) {
+ if (edge->IsControlEdge() && !edge->src()->IsSource()) {
+ ++count;
+ }
+ }
+ return count;
}
int TF_OperationGetControlInputs(TF_Operation* oper,
int max_control_inputs) {
int count = 0;
for (const auto* edge : oper->node.in_edges()) {
- if (edge->IsControlEdge()) {
+ if (edge->IsControlEdge() && !edge->src()->IsSource()) {
if (count < max_control_inputs) {
control_inputs[count] = ToOperation(edge->src());
}
int TF_OperationNumControlOutputs(TF_Operation* oper) {
int count = 0;
for (const auto* edge : oper->node.out_edges()) {
- if (edge->IsControlEdge()) {
+ if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
++count;
}
}
int max_control_outputs) {
int count = 0;
for (const auto* edge : oper->node.out_edges()) {
- if (edge->IsControlEdge()) {
+ if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
if (count < max_control_outputs) {
control_outputs[count] = ToOperation(edge->dst());
}
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
- // Create a graph with two nodes: x and 3
+ // Create a simple graph.
Placeholder(graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);
- // Export to a GraphDef
+ // Export to a GraphDef.
TF_Buffer* graph_def = TF_NewBuffer();
TF_GraphToGraphDef(graph, graph_def, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_TRUE(feed != nullptr);
ASSERT_TRUE(neg != nullptr);
+ // Test basic structure of the imported graph.
+ EXPECT_EQ(0, TF_OperationNumInputs(scalar));
+ EXPECT_EQ(0, TF_OperationNumInputs(feed));
+ ASSERT_EQ(1, TF_OperationNumInputs(neg));
+ TF_Output neg_input = TF_OperationInput({neg, 0});
+ EXPECT_EQ(scalar, neg_input.oper);
+ EXPECT_EQ(0, neg_input.index);
+
+ // Test that we can't see control edges involving the source and sink nodes.
+ TF_Operation* control_ops[100];
+ EXPECT_EQ(0, TF_OperationNumControlInputs(scalar));
+ EXPECT_EQ(0, TF_OperationGetControlInputs(scalar, control_ops, 100));
+ EXPECT_EQ(0, TF_OperationNumControlOutputs(scalar));
+ EXPECT_EQ(0, TF_OperationGetControlOutputs(scalar, control_ops, 100));
+
+ EXPECT_EQ(0, TF_OperationNumControlInputs(feed));
+ EXPECT_EQ(0, TF_OperationGetControlInputs(feed, control_ops, 100));
+ EXPECT_EQ(0, TF_OperationNumControlOutputs(feed));
+ EXPECT_EQ(0, TF_OperationGetControlOutputs(feed, control_ops, 100));
+
+ EXPECT_EQ(0, TF_OperationNumControlInputs(neg));
+ EXPECT_EQ(0, TF_OperationGetControlInputs(neg, control_ops, 100));
+ EXPECT_EQ(0, TF_OperationNumControlOutputs(neg));
+ EXPECT_EQ(0, TF_OperationGetControlOutputs(neg, control_ops, 100));
+
// Import it again, with an input mapping, return outputs, and a return
// operation, into the same graph.
TF_DeleteImportGraphDefOptions(opts);
ASSERT_TRUE(neg2 != nullptr);
// Check input mapping
- TF_Output neg_input = TF_OperationInput({neg, 0});
+ neg_input = TF_OperationInput({neg, 0});
EXPECT_EQ(scalar, neg_input.oper);
EXPECT_EQ(0, neg_input.index);