Misc. small optimizations in Grappler and shape inference code.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 12 Apr 2018 23:59:57 +0000 (16:59 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 13 Apr 2018 00:10:43 +0000 (17:10 -0700)
Impact on time per optimizer on inception graph:

model_pruner:          590 ms -> 550 ms   (-7%)
function_optimizer:    130 ms -> 130 ms   (-0%)
constant_folding:     7600 ms -> 7550 ms  (-0.7%)
arithmetic_optimizer: 1860 ms -> 1550 ms  (-20%)
loop_optimizer:        320 ms -> 320 ms   (-0%)
dependency_optimizer: 1300 ms -> 720 ms   (-45%)
layout:               1400 ms -> 1400 ms  (-0%)
memory_optimizer:     4200 ms -> 3540 ms  (-16%)
PiperOrigin-RevId: 192694528

12 files changed:
tensorflow/core/framework/shape_inference.cc
tensorflow/core/graph/graph_constructor.cc
tensorflow/core/grappler/costs/graph_memory.cc
tensorflow/core/grappler/grappler_item.cc
tensorflow/core/grappler/grappler_item.h
tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/dependency_optimizer.cc
tensorflow/core/grappler/optimizers/memory_optimizer.cc
tensorflow/core/grappler/optimizers/meta_optimizer.cc
tensorflow/core/grappler/utils.cc
tensorflow/core/grappler/utils.h

index cc1ec47..229b4a4 100644 (file)
@@ -40,6 +40,7 @@ InferenceContext::InferenceContext(
     : graph_def_version_(graph_def_version),
       node_def_(CHECK_NOTNULL(node_def)) {
   std::vector<ShapeHandle> input_tensors_as_shape_handles;
+  input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
   for (const TensorShapeProto& p : input_tensors_as_shapes) {
     ShapeHandle shape;
     construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
@@ -50,6 +51,7 @@ InferenceContext::InferenceContext(
   }
   PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
   if (!construction_status_.ok()) return;
+  inputs_.reserve(input_shapes.size());
   for (const TensorShapeProto& p : input_shapes) {
     ShapeHandle shape;
     construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
@@ -93,6 +95,7 @@ InferenceContext::InferenceContext(
     : graph_def_version_(graph_def_version),
       node_def_(CHECK_NOTNULL(node_def)) {
   std::vector<ShapeHandle> input_tensors_as_shape_handles;
+  input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
   for (const PartialTensorShape& p : input_tensors_as_shapes) {
     ShapeHandle shape;
     construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
@@ -103,6 +106,7 @@ InferenceContext::InferenceContext(
   }
   PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
   if (!construction_status_.ok()) return;
+  inputs_.reserve(input_shapes.size());
   for (const PartialTensorShape& p : input_shapes) {
     ShapeHandle shape;
     construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
@@ -229,9 +233,7 @@ void InferenceContext::PreInputInit(
   for (const auto& e : output_name_map_) {
     num_outputs = std::max(num_outputs, e.second.second);
   }
-  for (int i = 0; i < num_outputs; ++i) {
-    outputs_.push_back(nullptr);
-  }
+  outputs_.assign(num_outputs, nullptr);
   output_handle_shapes_and_types_.resize(num_outputs);
 }
 
@@ -469,13 +471,15 @@ Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix,
   TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s));
 
   // Merge the prefix dims and create the new output shapes.
+  const int32 rank_s = Rank(s);
   std::vector<DimensionHandle> dims;
+  dims.reserve(std::max(rank, rank_s));
   dims.resize(rank);
   for (int i = 0; i < rank; ++i) {
     TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i]));
   }
   *prefix_out = MakeShape(dims);
-  for (int i = rank; i < Rank(s); ++i) dims.push_back(Dim(s, i));
+  for (int i = rank; i < rank_s; ++i) dims.push_back(Dim(s, i));
   *s_out = MakeShape(dims);
   return Status::OK();
 }
@@ -1105,6 +1109,7 @@ Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second,
 
 Status InferenceContext::AttachContext(const Status& status) {
   std::vector<string> input_shapes;
+  input_shapes.reserve(inputs_.size());
   for (const ShapeHandle& input_shape : inputs_) {
     input_shapes.emplace_back(DebugString(input_shape));
   }
@@ -1112,6 +1117,7 @@ Status InferenceContext::AttachContext(const Status& status) {
   // Add information about the input tensors and partial tensor shapes used.
   std::vector<string> input_from_tensors_str;
   std::vector<string> input_from_tensors_as_shape_str;
+  input_from_tensors_as_shape_str.reserve(inputs_.size());
   for (int i = 0; i < inputs_.size(); ++i) {
     if (requested_input_tensor_as_partial_shape_[i] &&
         i < input_tensors_as_shapes_.size() &&
@@ -1233,9 +1239,7 @@ bool InferenceContext::RelaxHandleShapesAndMergeTypes(
   if (!refined) {
     return false;
   }
-  for (int i = 0; i < new_values.size(); ++i) {
-    (*to_update)[i] = new_values[i];
-  }
+  to_update->swap(new_values);
   return true;
 }
 
index 250992f..c678283 100644 (file)
@@ -666,20 +666,17 @@ Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) {
 void RemoveInputs(const std::vector<int>& inputs_to_remove, NodeDef* node_def,
                   std::vector<bool>* input_already_exists) {
   // Remove 'inputs_to_remove' from 'node_def'
-  // TODO(skyewm): is there a better way to do this?
-  std::vector<string> inputs;
-  inputs.reserve(node_def->input_size());
-  for (int i = 0; i < node_def->input_size(); ++i) {
-    inputs.push_back(node_def->input(i));
-  }
-  node_def->clear_input();
-  for (int i = 0, j = 0; i < inputs.size(); ++i) {
+  NodeDef copy;
+  copy.mutable_input()->Reserve(node_def->input_size() -
+                                inputs_to_remove.size());
+  for (int i = 0, j = 0; i < node_def->input_size(); ++i) {
     if (j < inputs_to_remove.size() && i == inputs_to_remove[j]) {
       ++j;
     } else {
-      node_def->add_input(inputs[i]);
+      copy.add_input()->swap(*node_def->mutable_input(i));
     }
   }
+  node_def->mutable_input()->Swap(copy.mutable_input());
   // Remove 'inputs_to_remove' from 'input_already_exists'
   for (int idx : inputs_to_remove) {
     input_already_exists->erase(input_already_exists->begin() + idx);
@@ -745,9 +742,21 @@ void GraphConstructor::AddControlDependencies(
   // dependencies
   for (const string& control_dep : opts_.control_dependencies) {
     string input = TensorId(control_dep, Graph::kControlSlot).ToString();
-    const protobuf::RepeatedPtrField<string>& inputs = node_def->input();
-    if (std::find(inputs.begin(), inputs.end(), input) != inputs.end()) {
-      // Control dependency already exists
+    bool found = false;
+    for (int i = node_def->input_size() - 1; i >= 0; --i) {
+      const string& node_input = node_def->input(i);
+      if (node_input[0] != '^') {
+        // Control inputs are at the end. Break when we reach the non-control
+        // inputs.
+        break;
+      }
+      if (node_input == input) {
+        // Control dependency already exists
+        found = true;
+        break;
+      }
+    }
+    if (found) {
       continue;
     }
     node_def->add_input(input);
@@ -761,10 +770,10 @@ void GraphConstructor::AddPrefixToNodeDef(
   node_def->set_name(strings::StrCat(prefix_, node_def->name()));
   // Update names of input nodes
   for (int i = 0; i < node_def->input_size(); ++i) {
-    StringPiece input(node_def->input(i));
     // Skip remapped inputs (which already exist in g_ and are not being
     // imported).
     if (input_already_exists[i]) continue;
+    StringPiece input(node_def->input(i));
     if (str_util::ConsumePrefix(&input, "^")) {
       node_def->set_input(i, strings::StrCat("^", prefix_, input));
     } else {
@@ -933,10 +942,10 @@ Status GraphConstructor::Convert() {
         }
       }
 
-      // TODO(ashankar): The line below means an additional copy of the NodeDef,
-      // which can be expensive if the NodeDef contains large tensors in it.
-      // Might make sense to change the API for ImportGraphDef to take a mutable
-      // GraphDef* and avoid the copying.
+      // TODO(ashankar): The line below means an additional copy of the
+      // NodeDef, which can be expensive if the NodeDef contains large tensors
+      // in it. Might make sense to change the API for ImportGraphDef to take
+      // a mutable GraphDef* and avoid the copying.
       imported_node_def = original_node_def;
       if (!opts_.input_map.empty()) {
         // Note that input_already_exists can shrink here
@@ -980,7 +989,7 @@ Status GraphConstructor::Convert() {
             src_node->num_outputs(), " outputs");
       }
 
-      inputs.push_back(InputInfo(id.first.ToString(), src_node, src_index));
+      inputs.emplace_back(id.first.ToString(), src_node, src_index);
     }
 
     if (has_data_back_edge && !IsMerge(*node_def)) {
@@ -1010,8 +1019,7 @@ Status GraphConstructor::Convert() {
       if (inputs[i].node == nullptr) {
         // Record this back edge, which will be added after all nodes
         // are created.
-        back_edges_.push_back(
-            EdgeInfo(inputs[i].name, inputs[i].index, node, i));
+        back_edges_.emplace_back(inputs[i].name, inputs[i].index, node, i);
       } else if (inputs[i].index == Graph::kControlSlot) {
         g_->AddControlEdge(inputs[i].node, node);
       } else {
index 3604de3..a5736d4 100644 (file)
@@ -14,7 +14,8 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/costs/graph_memory.h"
-#include <list>
+
+#include <deque>
 #include "tensorflow/core/framework/allocation_description.pb.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
@@ -120,7 +121,7 @@ int64 GraphMemory::InferMemUsageForNeighbors(
 static GraphMemory::LiveTensor* FindOrCreateLiveTensor(
     const string& node_name, int output_id,
     std::unordered_map<string, GraphMemory::LiveTensor*>* live_tensors,
-    std::list<GraphMemory::LiveTensor>* device_tensors) {
+    std::deque<GraphMemory::LiveTensor>* device_tensors) {
   string name = strings::StrCat(node_name, ":", output_id);
   GraphMemory::LiveTensor* live;
   auto it = live_tensors->find(name);
@@ -141,6 +142,10 @@ static GraphMemory::LiveTensor* FindOrCreateLiveTensor(
 
 namespace {
 struct Event {
+  Event(int64 _timestamp, bool _allocated,
+        const GraphMemory::LiveTensor* _tensor)
+      : timestamp(_timestamp), allocated(_allocated), tensor(_tensor) {}
+
   int64 timestamp;
   bool allocated;
   const GraphMemory::LiveTensor* tensor;
@@ -160,13 +165,15 @@ void GraphMemory::InferFromTrace(const StepStats& timeline) {
   }
 
   std::unordered_map<string, LiveTensor*> live_tensors;
-  std::unordered_map<string, std::list<LiveTensor>> live_tensors_per_device;
-
-  NodeMap node_map(&item_.graph);
+  std::unordered_map<string, std::deque<LiveTensor>> live_tensors_per_device;
+  std::unordered_map<string, const NodeDef*> node_map;
+  for (const NodeDef& node : item_.graph.node()) {
+    node_map[node.name()] = &node;
+  }
   for (const auto& dev_stats : timeline.dev_stats()) {
     const string& device_name = dev_stats.device();
     const bool is_gpu = (device_name.find("GPU:") || device_name.find("gpu:"));
-    std::list<LiveTensor>& device_tensors =
+    std::deque<LiveTensor>& device_tensors =
         live_tensors_per_device[dev_stats.device()];
     for (const auto& node_stats : dev_stats.node_stats()) {
       for (int i = 0; i < node_stats.output_size(); ++i) {
@@ -191,12 +198,13 @@ void GraphMemory::InferFromTrace(const StepStats& timeline) {
                                     node_stats.op_end_rel_micros()));
       }
 
-      const NodeDef* node = node_map.GetNode(node_stats.node_name());
-      if (!node) {
+      auto it = node_map.find(node_stats.node_name());
+      if (it == node_map.end()) {
         // Skip nodes inserted by TF since they don't exist in the original
         // graph (e.g _Send/_Recv nodes).
         continue;
       }
+      const NodeDef* node = it->second;
       std::unordered_set<int> swapped_inputs;
       if (is_gpu) {
         auto it = node->attr().find("_swap_to_host");
@@ -237,14 +245,16 @@ void GraphMemory::InferFromTrace(const StepStats& timeline) {
     std::vector<Event> events;
     events.reserve(2 * live_per_device.second.size());
     for (const auto& live : live_per_device.second) {
-      events.push_back(Event{live.allocation_time.count(), true, &live});
-      events.push_back(Event{live.deallocation_time.count(), false, &live});
+      events.emplace_back(static_cast<int64>(live.allocation_time.count()),
+                          true, &live);
+      events.emplace_back(static_cast<int64>(live.deallocation_time.count()),
+                          false, &live);
     }
     std::stable_sort(events.begin(), events.end());
     size_t peak = 0;
-    std::set<const LiveTensor*> live_at_peak;
+    std::unordered_set<const LiveTensor*> live_at_peak;
     size_t current = 0;
-    std::set<const LiveTensor*> currently_live;
+    std::unordered_set<const LiveTensor*> currently_live;
     for (int i = 0; i < events.size(); ++i) {
       const auto& event = events[i];
 
index ad86356..bbc0fed 100644 (file)
@@ -27,7 +27,7 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
-GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef&& graphDef) {
+GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef* graph_def) {
   id = other.id;
   feed = other.feed;
   fetch = other.fetch;
@@ -38,7 +38,7 @@ GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef&& graphDef) {
   restore_op = other.restore_op;
   save_restore_loc_tensor = other.save_restore_loc_tensor;
   queue_runners = other.queue_runners;
-  graph.Swap(&graphDef);
+  graph.Swap(graph_def);
 }
 
 std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const {
index 45eed47..cd165ac 100644 (file)
@@ -33,10 +33,12 @@ namespace grappler {
 // A TensorFlow model to optimize.
 // Models are represented by the combination of a graph, one of more fetch
 // nodes, and potentially a set of nodes to feed.
-// TODO(volunteer_needed): turn this struct into a class.
 struct GrapplerItem {
   GrapplerItem() = default;
-  GrapplerItem(const GrapplerItem& other, GraphDef&& graphDef);
+  GrapplerItem(const GrapplerItem& other, GraphDef&& graph_def)
+      : GrapplerItem(other, &graph_def) {}
+  // Swaps *graph_def with an empty GraphDef.
+  GrapplerItem(const GrapplerItem& other, GraphDef* graph_def);
   virtual ~GrapplerItem() = default;
 
   string id;  // A unique id for this item
index 463c332..60b1af4 100644 (file)
@@ -253,9 +253,8 @@ NodeDef* GetTailOfValuePreservingChain(
     const NodeDef& node, const NodeMap& node_map,
     const std::unordered_set<string>& nodes_to_preserve) {
   auto is_value_preserving_non_branching = [&](const NodeDef& node) {
-    return IsValuePreserving(node) &&
-           NumNonControlOutputs(node, node_map) == 1 &&
-           nodes_to_preserve.count(node.name()) == 0;
+    return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
+           IsValuePreserving(node) && NumNonControlOutputs(node, node_map) == 1;
   };
   return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
                         is_value_preserving_non_branching);
@@ -2023,12 +2022,11 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
 Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
                                      const GrapplerItem& item,
                                      GraphDef* optimized_graph) {
-  GrapplerItem optimized_item(item);
-  optimized_graph_ = &optimized_item.graph;
-
   // Set up helper data structures.
   nodes_to_preserve_ = item.NodesToPreserve();
   fetch_nodes_known_ = !item.fetch.empty();
+  *optimized_graph = item.graph;
+  optimized_graph_ = optimized_graph;
   node_map_.reset(new NodeMap(optimized_graph_));
 
   DedupComputations();
@@ -2037,8 +2035,9 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
   // optimize larger subgraphs starting from the roots with more inputs.
   TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));
 
-  // Shapes are only needed in aggressive mode.
-  graph_properties_.reset(new GraphProperties(item));
+  GrapplerItem optimized_item(item, optimized_graph);
+  optimized_graph_ = &optimized_item.graph;
+  graph_properties_.reset(new GraphProperties(optimized_item));
   const Status status = graph_properties_->InferStatically(false);
   const bool can_use_shapes = status.ok();
   if (!can_use_shapes) {
index b2a1ce6..e29aaa2 100644 (file)
@@ -1004,7 +1004,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
 
   for (const auto& input : node.input()) {
     int port = 0;
-    ParseNodeName(input, &port);
+    ParseNodeNameAsStringPiece(input, &port);
     if (port < 0) {
       // Control dependency
       break;
@@ -2084,9 +2084,9 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
           left_child_is_constant ? left_child : right_child;
       // Make sure that it is safe to change the value of the child node->
       if (op_child_node->input_size() < 2 ||
-          NumNonControlOutputs(*op_child_node, *node_map_) > 1 ||
           nodes_to_preserve_.find(op_child_node->name()) !=
-              nodes_to_preserve_.end()) {
+              nodes_to_preserve_.end() ||
+          NumNonControlOutputs(*op_child_node, *node_map_) > 1) {
         continue;
       }
 
index ed9bce4..7b7fd81 100644 (file)
@@ -109,23 +109,12 @@ bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) {
 }
 
 bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) {
-  if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
-    return false;
-  }
-  if (!fetch_nodes_known_ || NumNonControlOutputs(node, *node_map_) > 0) {
-    // The output values of this node may be needed.
-    return false;
-  }
-  if (IsMerge(node) || IsSwitch(node)) {
-    return false;
-  }
-  if (ModifiesFrameInfo(node)) {
-    return false;
-  }
-  if (!IsFreeOfSideEffect(node)) {
+  if (!fetch_nodes_known_ ||
+      nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
     return false;
   }
-  if (node.op() == "ControlTrigger") {
+  if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node) ||
+      !IsFreeOfSideEffect(node)) {
     return false;
   }
   if (node.op().rfind("Submodel", 0) == 0) {
@@ -136,16 +125,21 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) {
   if (!status.ok() || op_def->output_arg_size() == 0) {
     return false;
   }
-
+  const std::unordered_set<string> do_not_rewrite_ops{
+      "Assert",      "CheckNumerics",         "_Retval",
+      "_Arg",        "_ParallelConcatUpdate", "_TPUExecute",
+      "_TPUCompile", "ControlTrigger"};
+  if (do_not_rewrite_ops.find(node.op()) != do_not_rewrite_ops.end()) {
+    return false;
+  }
   if (!SafeToRemoveIdentity(node)) {
     return false;
   }
-
-  const std::unordered_set<string> do_not_rewrite_ops{
-      "Assert",     "CheckNumerics",         "_Retval",
-      "_Arg",       "_ParallelConcatUpdate", "_TPUExecute",
-      "_TPUCompile"};
-  return do_not_rewrite_ops.find(node.op()) == do_not_rewrite_ops.end();
+  if (NumNonControlOutputs(node, *node_map_) > 0) {
+    // The output values of this node may be needed.
+    return false;
+  }
+  return true;
 }
 
 void DependencyOptimizer::OptimizeNode(int node_idx,
@@ -164,7 +158,8 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
       bool data_connection = false;
       for (int i = fanout->input_size() - 1; i >= 0; --i) {
         int pos;
-        string input_name = ParseNodeName(fanout->input(i), &pos);
+        StringPiece input_name =
+            ParseNodeNameAsStringPiece(fanout->input(i), &pos);
         if (input_name == node_name) {
           if (pos < 0) {
             fanout->mutable_input()->SwapElements(i, fanout->input_size() - 1);
@@ -358,8 +353,8 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
           for (int j = 0; j < consumer->input_size(); ++j) {
             const string& old_input = consumer->input(j);
             int old_input_pos;
-            string old_input_node_name =
-                ParseNodeName(old_input, &old_input_pos);
+            StringPiece old_input_node_name =
+                ParseNodeNameAsStringPiece(old_input, &old_input_pos);
             if (old_input_node_name == node_name) {
               if (old_input_pos >= 0) {
                 // Regular input
index 27e9d2c..c1fee0e 100644 (file)
@@ -1227,7 +1227,7 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                              recomputation_targets_name_scope_, optimized_graph,
                              item);
 
-  GrapplerItem optimized_item(item, std::move(*optimized_graph));
+  GrapplerItem optimized_item(item, optimized_graph);
   std::unordered_set<string> skip_list;
   // Bound the number of rewrite passes to avoid long processing times on graphs
   // that simply won't fit in memory.
index 5723e39..558b8a7 100644 (file)
@@ -178,45 +178,41 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
       cfg_.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS
           ? 1
           : cfg_.meta_optimizer_iterations();
+  GrapplerItem optimized_item = item;
+  optimized_graph->Swap(&optimized_item.graph);
   for (int iteration = 0; iteration < num_iterations; ++iteration) {
     VLOG(1) << "Starting optimization iteration " << iteration + 1;
     for (const auto& optimizer : optimizers) {
+      // Invariant: optimized_graph contains the most recently optimized
+      // version of the graph.
       if (iteration > 0 && run_once_optimizers.count(optimizer->name())) {
         continue;
       }
-      if (!already_optimized) {
-        Status status = optimizer->Optimize(cluster, item, optimized_graph);
-        string result;
-        if (!status.ok()) {
-          VLOG(1) << "Not able to apply optimizer " << optimizer->name()
-                  << ". Return status: " << status.ToString();
-          result = status.ToString();
-        } else {
-          already_optimized = true;
-          result = strings::StrCat(
-              "OK. ", PrintSizesBeforeAfter(item.graph, *optimized_graph));
-        }
-        result_.push_back(std::make_pair(optimizer->name(), result));
-        VLOG(1) << "Optimizer " << optimizer->name()
-                << " return status: " << result;
+      uint64 start_us = Env::Default()->NowMicros();
+      // This swaps the current optimized_graph into optimized item and
+      // resets optimized_graph to an empty graph.
+      optimized_graph->Swap(&optimized_item.graph);
+      *optimized_graph = GraphDef();
+      Status status =
+          optimizer->Optimize(cluster, optimized_item, optimized_graph);
+
+      uint64 end_us = Env::Default()->NowMicros();
+      float duration_ms = (end_us - start_us) / 1000.0f;
+      string result;
+      if (!status.ok()) {
+        VLOG(1) << "Not able to apply optimizer " << optimizer->name() << ": "
+                << status.ToString();
+        optimized_graph->Swap(&optimized_item.graph);
+        result = status.ToString();
       } else {
-        GrapplerItem optimized_item(item, std::move(*optimized_graph));
-        Status status =
-            optimizer->Optimize(cluster, optimized_item, optimized_graph);
-        string result;
-        if (!status.ok()) {
-          VLOG(1) << "Not able to apply optimizer " << optimizer->name() << ": "
-                  << status.ToString();
-          optimized_graph->Swap(&optimized_item.graph);
-          result = status.ToString();
-        } else {
-          result = strings::StrCat(
-              optimizer->name(), ": ",
-              PrintSizesBeforeAfter(optimized_item.graph, *optimized_graph));
-        }
-        result_.push_back(std::make_pair(optimizer->name(), result));
-        VLOG(1) << result;
+        already_optimized = true;
+        result = strings::StrCat(
+            optimizer->name(), ": ",
+            PrintSizesBeforeAfter(optimized_item.graph, *optimized_graph),
+            ", time = ", duration_ms, "ms.");
       }
+      result_.emplace_back(optimizer->name(), result);
+      VLOG(1) << result;
     }
   }
 
@@ -230,10 +226,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
               item.graph.library().gradient_size());
     DCHECK_EQ(optimized_graph->versions().producer(),
               item.graph.versions().producer());
-  } else {
-    *optimized_graph = item.graph;
   }
-
   return Status::OK();
 }
 
index 534fe67..7398d2c 100644 (file)
@@ -142,38 +142,12 @@ bool IsSameInput(const string& name1, const string& name2) {
     return true;
   }
   int position1;
-  string node1 = ParseNodeName(name1, &position1);
+  StringPiece node1 = ParseNodeNameAsStringPiece(name1, &position1);
   int position2;
-  string node2 = ParseNodeName(name2, &position2);
+  StringPiece node2 = ParseNodeNameAsStringPiece(name2, &position2);
   return (position1 == position2) && (node1 == node2);
 }
 
-string ParseNodeName(const string& name, int* position) {
-  // Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any)
-  // to get a node name.
-  strings::Scanner scan(name);
-  scan.ZeroOrOneLiteral("^")
-      .RestartCapture()
-      .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
-      .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
-  StringPiece capture;
-  StringPiece remaining;
-  if (scan.Peek(':') != ':' || !scan.GetResult(&remaining, &capture)) {
-    *position = 0;
-    return "";
-  } else {
-    if (name[0] == '^') {
-      *position = -1;
-    } else if (remaining.empty()) {
-      *position = 0;
-    } else {
-      // Skip the first ':' character.
-      CHECK(strings::safe_strto32(remaining.substr(1), position));
-    }
-    return capture.ToString();
-  }
-}
-
 bool IsControlInput(const string& name) {
   return !name.empty() && name[0] == '^';
 }
@@ -185,7 +159,7 @@ string NodeName(const string& name) {
 
 int NodePosition(const string& name) {
   int position;
-  ParseNodeName(name, &position);
+  ParseNodeNameAsStringPiece(name, &position);
   return position;
 }
 
@@ -275,13 +249,20 @@ int NumNonControlInputs(const NodeDef& node) {
 
 int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
   int num_outputs = 0;
+  int pos;
   for (const NodeDef* output : node_map.GetOutputs(node.name())) {
     for (const string& node_as_input : output->input()) {
       if (IsControlInput(node_as_input)) {
         break;
       }
-      if (NodeName(node_as_input) == node.name()) {
+      if (node_as_input == node.name()) {
         ++num_outputs;
+      } else {
+        const StringPiece name =
+            ParseNodeNameAsStringPiece(node_as_input, &pos);
+        if (name == node.name()) {
+          ++num_outputs;
+        }
       }
     }
   }
index 11555d7..b15667d 100644 (file)
@@ -26,8 +26,10 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/strings/scanner.h"
 
 namespace tensorflow {
 namespace grappler {
@@ -107,8 +109,38 @@ string NodeName(const string& name);
 // Get the trailing position number ":{digits}" (if any) of a node name.
 int NodePosition(const string& name);
 
+inline StringPiece ParseNodeNameAsStringPiece(const string& name,
+                                              int* position) {
+  // Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any)
+  // to get a node name.
+  strings::Scanner scan(name);
+  scan.ZeroOrOneLiteral("^")
+      .RestartCapture()
+      .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
+      .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
+  StringPiece capture;
+  StringPiece remaining;
+  if (scan.Peek(':') != ':' || !scan.GetResult(&remaining, &capture)) {
+    *position = 0;
+    static const string empty;
+    return StringPiece(empty);
+  } else {
+    if (name[0] == '^') {
+      *position = -1;
+    } else if (remaining.empty()) {
+      *position = 0;
+    } else {
+      // Skip the first ':' character.
+      CHECK(strings::safe_strto32(remaining.substr(1), position));
+    }
+    return capture;
+  }
+}
+
 // Returns the node name and position in a single call.
-string ParseNodeName(const string& name, int* position);
+inline string ParseNodeName(const string& name, int* position) {
+  return ParseNodeNameAsStringPiece(name, position).ToString();
+}
 
 // Add a prefix to a node name with a custom delimiter.
 string AddPrefixToNodeName(const string& name, const string& prefix,