Fix a few bugs in ArithmeticOptimizer and make it robust to failures of shape inference.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Apr 2018 19:18:04 +0000 (12:18 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 19:21:28 +0000 (12:21 -0700)
PiperOrigin-RevId: 191922788

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc

index 59a5695..7bf264b 100644 (file)
@@ -237,17 +237,16 @@ bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
     return false;
   }
 
-  // Now, src_shape and dst_shape have at most one dimension with unknown
-  // sizes, and are compatible. Therefore, the reshape is a no-op when
-  //
-  // 1. at least one of them is fully-defined, or
-  // 2. both are partially defined and the -1 appears on the same dimension,
-  //    i.e., IsIdenticalTo returns true.
-  if (src_num_unknown_dim_sizes == 1 && dst_num_unknown_dim_sizes == 1) {
-    return dst_shape.IsIdenticalTo(src_shape);
+  // If dst_num_unknown_dim_sizes != src_num_unknown_dim_sizes we would weaken
+  // shape inference in subsequent passes if we removed this reshape.
+  if (src_num_unknown_dim_sizes != dst_num_unknown_dim_sizes) {
+    return false;
   }
 
-  return true;
+  // Remove the reshape if both are fully defined or partially defined and the
+  // unknown or symbolic shape appears on the same dimension, i.e., if
+  // IsIdenticalTo returns true.
+  return dst_shape.IsIdenticalTo(src_shape);
 }
 
 NodeDef* GetTailOfValuePreservingChain(
@@ -727,7 +726,9 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
 
         // Hoist non-shared factors up into the new AddN node.
         for (int i = 0; i < unique_factors.size(); ++i) {
-          new_add_node->set_input(i, unique_factors[i]);
+          const string& unique_factor_i = unique_factors[i];
+          new_add_node->set_input(i, unique_factor_i);
+          ctx_.node_map->AddOutput(unique_factor_i, new_add_node->name());
         }
 
         // Add control deps on add node
@@ -859,13 +860,18 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
     NodeDef* node_perm;
     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
+    if (!IsConstant(*node_perm)) {
+      return Status::OK();
+    }
     std::vector<int64> node_perm_values;
     TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values));
-
     if (input->op() == node->op()) {
       // Remove pairs of transposes that cancel each other.
       NodeDef* input_perm;
       TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &input_perm));
+      if (!IsConstant(*input_perm)) {
+        return Status::OK();
+      }
       std::vector<int64> input_perm_values;
       TF_RETURN_IF_ERROR(GetPermutation(*input_perm, &input_perm_values));
       if (AreInversePermutations(node_perm_values, input_perm_values)) {
@@ -1337,9 +1343,9 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
     //      ^      |
     //      |      |
     //    input ---+
-    NodeDef* reshape = node_map_->GetNode(node->name());
+    NodeDef* reshape = const_cast<NodeDef*>(node);
     int output_pos = 0;
-    string input_node_name = ParseNodeName(node->input(0), &output_pos);
+    string input_node_name = ParseNodeName(reshape->input(0), &output_pos);
     const NodeDef* input = node_map_->GetNode(input_node_name);
     if (input->op() == "Reshape" && !HasControlInputs(*input)) {
       reshape->set_input(0, input->input(0));
@@ -1653,7 +1659,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
   return "";
 }
 
-Status ArithmeticOptimizer::SimplifyArithmeticOps() {
+Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
   SetVector<NodeDef*> nodes_to_simplify;
   nodes_to_simplify.Reserve(optimized_graph_->node_size());
   for (int i = 0; i < optimized_graph_->node_size(); ++i) {
@@ -1668,11 +1674,11 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
   const auto stop = [](const string& result) { return !result.empty(); };
   GraphOptimizerStagePipeline<string> pipeline(stop);
 
-  if (options_.combine_add_to_addn)
+  if (options_.combine_add_to_addn && can_use_shapes)
     pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
-  if (options_.hoist_common_factor_out_of_aggregation)
+  if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes)
     pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
-  if (options_.remove_identity_transpose)
+  if (options_.remove_identity_transpose && can_use_shapes)
     pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
   if (options_.remove_redundant_bitcast)
     pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
@@ -1759,10 +1765,14 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
 
   // Shapes are only needed in aggressive mode.
   graph_properties_.reset(new GraphProperties(item));
-  TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false));
+  const Status status = graph_properties_->InferStatically(false);
+  const bool can_use_shapes = status.ok();
+  if (!can_use_shapes) {
+    VLOG(1) << "Shape inference failed." << status.error_message();
+  }
 
   // Perform the optimizations.
-  TF_RETURN_IF_ERROR(SimplifyArithmeticOps());
+  TF_RETURN_IF_ERROR(SimplifyArithmeticOps(can_use_shapes));
 
   optimized_graph->Swap(optimized_graph_);
   return Status::OK();
index 7e81ed0..39b89de 100644 (file)
@@ -105,7 +105,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
 
   // Runs peep-hole optimizations on `optimized_graph`, e.g., removing inverse
   // transposes.
-  Status SimplifyArithmeticOps();
+  Status SimplifyArithmeticOps(bool can_use_shapes);
   // Tries to simplify the expression that roots at `node` and replaces the uses
   // of `node` to the simplified expression. Returns the name of the simplified
   // tensor (e.g. "split:1") or an emtpy string if no simplification is
index 7044705..1ea57f7 100644 (file)
@@ -42,6 +42,10 @@ Status GetInputNode(const GraphOptimizerContext& ctx, const string& input,
 Status GetTensorProperties(const GraphOptimizerContext& ctx,
                            const string& tensor,
                            OpInfo::TensorProperties* properties) {
+  if (ctx.graph_properties == nullptr) {
+    return errors::InvalidArgument("Graph properties are unknown.");
+  }
+
   int port;
   string tensor_node_name = ParseNodeName(tensor, &port);
   if (port < 0) {