// more with the original node name.
for (const auto& fetch : item.fetch) {
const NodeDef* fetch_node = node_map_->GetNode(fetch);
- if (fetch_node && NumOutputs(*fetch_node) == 1) {
+ if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) {
nodes_whitelist_.insert(fetch_node->name());
}
}
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
: strings::StrCat("^", node_name);
}
-int NumOutputs(const NodeDef& node) {
+int NumOutputs(const NodeDef& node, GraphDef* graph) {
int num_outputs = 0;
const OpDef* op_def = nullptr;
auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
num_outputs++;
}
}
+ } else {
+ FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library());
+ auto status = fdef.LookUpOpDef(node.op(), &op_def);
+ if (status.ok()) {
+ num_outputs = op_def->output_arg_size();
+ }
}
return num_outputs;
}
// Returns the number of outputs of a node according to its OpDef. Note that
// some of the outputs may be unconnected.
-int NumOutputs(const NodeDef& node);
+int NumOutputs(const NodeDef& node, GraphDef* graph);
// Number of connected non-control inputs.
int NumNonControlInputs(const NodeDef& node);
}
TEST_F(UtilsTest, NumOutputs) {
- EXPECT_EQ(2, NumOutputs(CreateConcatOffsetNode()));
- EXPECT_EQ(5, NumOutputs(CreateFusedBatchNormNode()));
- EXPECT_EQ(1, NumOutputs(CreateDequeueNode()));
+ GraphDef graph;
+ EXPECT_EQ(2, NumOutputs(CreateConcatOffsetNode(), &graph));
+ EXPECT_EQ(5, NumOutputs(CreateFusedBatchNormNode(), &graph));
+ EXPECT_EQ(1, NumOutputs(CreateDequeueNode(), &graph));
}
TEST_F(UtilsTest, AsControlDependency) {