From: A. Unique TensorFlower Date: Mon, 5 Feb 2018 22:43:35 +0000 (-0800) Subject: Backward pass implementation for fusion optimizer. X-Git-Tag: upstream/v1.7.0~31^2~999 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1c762f70caf7004470cdfa599b4eb7a76e5bcc78;p=platform%2Fupstream%2Ftensorflow.git Backward pass implementation for fusion optimizer. PiperOrigin-RevId: 184589487 --- diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 0aeff62..37a4759 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1658,7 +1658,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, // more with the original node name. for (const auto& fetch : item.fetch) { const NodeDef* fetch_node = node_map_->GetNode(fetch); - if (fetch_node && NumOutputs(*fetch_node) == 1) { + if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) { nodes_whitelist_.insert(fetch_node->name()); } } diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 8099214..634577e 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.h" @@ -207,7 +208,7 @@ string AsControlDependency(const string& node_name) { : strings::StrCat("^", node_name); } -int NumOutputs(const NodeDef& node) { +int NumOutputs(const NodeDef& node, GraphDef* graph) { int num_outputs = 0; const OpDef* op_def = nullptr; auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); @@ -222,6 +223,12 @@ int NumOutputs(const NodeDef& node) { num_outputs++; } } + } else { + FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library()); + auto status = fdef.LookUpOpDef(node.op(), &op_def); + if (status.ok()) { + num_outputs = op_def->output_arg_size(); + } } return num_outputs; } diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index c04a9a6..8840c44 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -135,7 +135,7 @@ string AsControlDependency(const string& node); // Returns the number of outputs of a node according to its OpDef. Note that // some of the outputs may be unconnected. -int NumOutputs(const NodeDef& node); +int NumOutputs(const NodeDef& node, GraphDef* graph); // Number of connected non-control inputs. int NumNonControlInputs(const NodeDef& node); diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index 77371c3..ba4e6b1 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -177,9 +177,10 @@ TEST_F(UtilsTest, ExecuteWithTimeout) { } TEST_F(UtilsTest, NumOutputs) { - EXPECT_EQ(2, NumOutputs(CreateConcatOffsetNode())); - EXPECT_EQ(5, NumOutputs(CreateFusedBatchNormNode())); - EXPECT_EQ(1, NumOutputs(CreateDequeueNode())); + GraphDef graph; + EXPECT_EQ(2, NumOutputs(CreateConcatOffsetNode(), &graph)); + EXPECT_EQ(5, NumOutputs(CreateFusedBatchNormNode(), &graph)); + EXPECT_EQ(1, NumOutputs(CreateDequeueNode(), &graph)); } TEST_F(UtilsTest, AsControlDependency) {