From: A. Unique TensorFlower Date: Fri, 6 Apr 2018 19:18:04 +0000 (-0700) Subject: Fix a few bugs in ArithmeticOptimizer and make it robust to failures of shape inference. X-Git-Tag: tflite-v0.1.7~16^2^2~91 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4d90c62824a2e4e445efab58d2c5829774a884ea;p=platform%2Fupstream%2Ftensorflow.git Fix a few bugs in ArithmeticOptimizer and make it robust to failures of shape inference. PiperOrigin-RevId: 191922788 --- diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 59a5695..7bf264b 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -237,17 +237,16 @@ bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input, return false; } - // Now, src_shape and dst_shape have at most one dimension with unknown - // sizes, and are compatible. Therefore, the reshape is a no-op when - // - // 1. at least one of them is fully-defined, or - // 2. both are partially defined and the -1 appears on the same dimension, - // i.e., IsIdenticalTo returns true. - if (src_num_unknown_dim_sizes == 1 && dst_num_unknown_dim_sizes == 1) { - return dst_shape.IsIdenticalTo(src_shape); + // If dst_num_unknown_dim_sizes != src_num_unknown_dim_sizes we would weaken + // shape inference in subsequent passes if we removed this reshape. + if (src_num_unknown_dim_sizes != dst_num_unknown_dim_sizes) { + return false; } - return true; + // Remove the reshape if both are fully defined or partially defined and the + // unknown or symbolic shape appears on the same dimension, i.e., if + // IsIdenticalTo returns true. + return dst_shape.IsIdenticalTo(src_shape); } NodeDef* GetTailOfValuePreservingChain( @@ -727,7 +726,9 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { // Hoist non-shared factors up into the new AddN node. for (int i = 0; i < unique_factors.size(); ++i) { - new_add_node->set_input(i, unique_factors[i]); + const string& unique_factor_i = unique_factors[i]; + new_add_node->set_input(i, unique_factor_i); + ctx_.node_map->AddOutput(unique_factor_i, new_add_node->name()); } // Add control deps on add node @@ -859,13 +860,18 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); NodeDef* node_perm; TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm)); + if (!IsConstant(*node_perm)) { + return Status::OK(); + } std::vector node_perm_values; TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values)); - if (input->op() == node->op()) { // Remove pairs of transposes that cancel each other. NodeDef* input_perm; TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &input_perm)); + if (!IsConstant(*input_perm)) { + return Status::OK(); + } std::vector input_perm_values; TF_RETURN_IF_ERROR(GetPermutation(*input_perm, &input_perm_values)); if (AreInversePermutations(node_perm_values, input_perm_values)) { @@ -1337,9 +1343,9 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( // ^ | // | | // input ---+ - NodeDef* reshape = node_map_->GetNode(node->name()); + NodeDef* reshape = const_cast(node); int output_pos = 0; - string input_node_name = ParseNodeName(node->input(0), &output_pos); + string input_node_name = ParseNodeName(reshape->input(0), &output_pos); const NodeDef* input = node_map_->GetNode(input_node_name); if (input->op() == "Reshape" && !HasControlInputs(*input)) { reshape->set_input(0, input->input(0)); @@ -1653,7 +1659,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( return ""; } -Status ArithmeticOptimizer::SimplifyArithmeticOps() { +Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { SetVector nodes_to_simplify; nodes_to_simplify.Reserve(optimized_graph_->node_size()); for (int i = 0; i < optimized_graph_->node_size(); ++i) { @@ -1668,11 +1674,11 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() { const auto stop = [](const string& result) { return !result.empty(); }; GraphOptimizerStagePipeline pipeline(stop); - if (options_.combine_add_to_addn) + if (options_.combine_add_to_addn && can_use_shapes) pipeline.AddStage(ctx, ctx_ext); - if (options_.hoist_common_factor_out_of_aggregation) + if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes) pipeline.AddStage(ctx, ctx_ext); - if (options_.remove_identity_transpose) + if (options_.remove_identity_transpose && can_use_shapes) pipeline.AddStage(ctx, ctx_ext); if (options_.remove_redundant_bitcast) pipeline.AddStage(ctx, ctx_ext); @@ -1759,10 +1765,14 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, // Shapes are only needed in aggressive mode. graph_properties_.reset(new GraphProperties(item)); - TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false)); + const Status status = graph_properties_->InferStatically(false); + const bool can_use_shapes = status.ok(); + if (!can_use_shapes) { + VLOG(1) << "Shape inference failed." << status.error_message(); + } // Perform the optimizations. - TF_RETURN_IF_ERROR(SimplifyArithmeticOps()); + TF_RETURN_IF_ERROR(SimplifyArithmeticOps(can_use_shapes)); optimized_graph->Swap(optimized_graph_); return Status::OK(); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 7e81ed0..39b89de 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -105,7 +105,7 @@ class ArithmeticOptimizer : public GraphOptimizer { // Runs peep-hole optimizations on `optimized_graph`, e.g., removing inverse // transposes. - Status SimplifyArithmeticOps(); + Status SimplifyArithmeticOps(bool can_use_shapes); // Tries to simplify the expression that roots at `node` and replaces the uses // of `node` to the simplified expression. Returns the name of the simplified // tensor (e.g. "split:1") or an emtpy string if no simplification is diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc index 7044705..1ea57f7 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc @@ -42,6 +42,10 @@ Status GetInputNode(const GraphOptimizerContext& ctx, const string& input, Status GetTensorProperties(const GraphOptimizerContext& ctx, const string& tensor, OpInfo::TensorProperties* properties) { + if (ctx.graph_properties == nullptr) { + return errors::InvalidArgument("Graph properties are unknown."); + } + int port; string tensor_node_name = ParseNodeName(tensor, &port); if (port < 0) {