Backward pass implementation for fusion optimizer.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 5 Feb 2018 22:43:35 +0000 (14:43 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Feb 2018 22:47:27 +0000 (14:47 -0800)
PiperOrigin-RevId: 184589487

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/utils.cc
tensorflow/core/grappler/utils.h
tensorflow/core/grappler/utils_test.cc

index 0aeff62..37a4759 100644 (file)
@@ -1658,7 +1658,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
   // more with the original node name.
   for (const auto& fetch : item.fetch) {
     const NodeDef* fetch_node = node_map_->GetNode(fetch);
-    if (fetch_node && NumOutputs(*fetch_node) == 1) {
+    if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) {
       nodes_whitelist_.insert(fetch_node->name());
     }
   }
index 8099214..634577e 100644 (file)
@@ -17,6 +17,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_def.pb.h"
 #include "tensorflow/core/framework/types.h"
@@ -207,7 +208,7 @@ string AsControlDependency(const string& node_name) {
              : strings::StrCat("^", node_name);
 }
 
-int NumOutputs(const NodeDef& node) {
+int NumOutputs(const NodeDef& node, GraphDef* graph) {
   int num_outputs = 0;
   const OpDef* op_def = nullptr;
   auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
@@ -222,6 +223,12 @@ int NumOutputs(const NodeDef& node) {
         num_outputs++;
       }
     }
+  } else {
+    FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library());
+    auto status = fdef.LookUpOpDef(node.op(), &op_def);
+    if (status.ok()) {
+      num_outputs = op_def->output_arg_size();
+    }
   }
   return num_outputs;
 }
index c04a9a6..8840c44 100644 (file)
@@ -135,7 +135,7 @@ string AsControlDependency(const string& node);
 
 // Returns the number of outputs of a node according to its OpDef. Note that
 // some of the outputs may be unconnected.
-int NumOutputs(const NodeDef& node);
+int NumOutputs(const NodeDef& node, GraphDef* graph);
 
 // Number of connected non-control inputs.
 int NumNonControlInputs(const NodeDef& node);
index 77371c3..ba4e6b1 100644 (file)
@@ -177,9 +177,10 @@ TEST_F(UtilsTest, ExecuteWithTimeout) {
 }
 
 TEST_F(UtilsTest, NumOutputs) {
-  EXPECT_EQ(2, NumOutputs(CreateConcatOffsetNode()));
-  EXPECT_EQ(5, NumOutputs(CreateFusedBatchNormNode()));
-  EXPECT_EQ(1, NumOutputs(CreateDequeueNode()));
+  GraphDef graph;
+  EXPECT_EQ(2, NumOutputs(CreateConcatOffsetNode(), &graph));
+  EXPECT_EQ(5, NumOutputs(CreateFusedBatchNormNode(), &graph));
+  EXPECT_EQ(1, NumOutputs(CreateDequeueNode(), &graph));
 }
 
 TEST_F(UtilsTest, AsControlDependency) {