Automated g4 rollback of changelist 190801044
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 28 Mar 2018 21:59:53 +0000 (14:59 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 28 Mar 2018 22:02:13 +0000 (15:02 -0700)
PiperOrigin-RevId: 190839672

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc
tensorflow/core/grappler/optimizers/graph_optimizer_stage.h

index 629872b..5dd0b6f 100644 (file)
@@ -196,6 +196,8 @@ void SetSourceDataType(DataType dtype, NodeDef* node) {
 
 bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
 
+const char kOutputShapesAttr[] = "_output_shapes";
+
 // Shape is symbolically defined if it has a known rank, and each dimension is
 // defined, or is an unknown symbol (dim.size <= -2).
 bool ShapeIsSymbolicallyDefined(const TensorShapeProto& shape) {
@@ -232,19 +234,16 @@ bool ShapesSymbolicallyEqual(const OpInfo::TensorProperties& left,
 // Returns whether `reshape` is an identity op. The tensor that `reshape`
 // reshapes is the `output_pos`-th output of node `input`.
 bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
-                       const int output_pos,
-                       const GraphProperties& graph_properties) {
-  const std::vector<OpInfo::TensorProperties>& reshape_props =
-      graph_properties.GetOutputProperties(reshape.name());
-  const std::vector<OpInfo::TensorProperties>& input_props =
-      graph_properties.GetOutputProperties(input.name());
-  if (reshape_props.empty() || input_props.empty() ||
-      input_props.size() <= output_pos) {
+                       const int output_pos) {
+  if (!reshape.attr().count(kOutputShapesAttr) ||
+      !input.attr().count(kOutputShapesAttr)) {
     return false;
   }
 
-  const PartialTensorShape& src_shape = input_props[output_pos].shape();
-  const PartialTensorShape& dst_shape = reshape_props[0].shape();
+  PartialTensorShape src_shape(
+      input.attr().at(kOutputShapesAttr).list().shape(output_pos));
+  PartialTensorShape dst_shape(
+      reshape.attr().at(kOutputShapesAttr).list().shape(0));
   if (src_shape.unknown_rank() || dst_shape.unknown_rank()) {
     return false;
   }
@@ -1273,8 +1272,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
     // outputs tensors of shape [M, N] while feeding it with tensors of shape
     // [M*N] (or worse). The reshape nodes are then necessary to update the
     // tensor metadata to the required shape.
-    if (can_use_shapes_ &&
-        ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_)) {
+    if (ReshapeIsIdentity(*reshape, *input, output_pos)) {
       return reshape->input(0);
     }
   }
@@ -1588,11 +1586,11 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
 
   std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
 
-  if (options_.combine_add_to_addn && can_use_shapes_) {
+  if (options_.combine_add_to_addn) {
     stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
         new AddOpsRewriteStage(ctx, ctx_ext)));
   }
-  if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes_) {
+  if (options_.hoist_common_factor_out_of_aggregation) {
     stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
         new HoistCommonFactorOutOfAggregation(ctx, ctx_ext)));
   }
@@ -1629,15 +1627,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
     if (simplified_tensor.empty()) {
       for (auto& stage : stages) {
         if (stage->IsSupported(node)) {
-          const Status stage_status =
-              stage->TrySimplify(node, &simplified_tensor);
-          // Each stage must be "error safe" (just like exception safe). In
-          // case of any error it must leave optimized graph unmodified.
-          if (!stage_status.ok()) {
-            LOG(WARNING) << "Failed to run arithmetic optimizer stage "
-                         << stage->stage_name()
-                         << ". Error: " << stage_status.error_message();
-          }
+          TF_RETURN_IF_ERROR(stage->TrySimplify(node, &simplified_tensor));
           if (!simplified_tensor.empty()) {
             break;
           }
@@ -1704,16 +1694,19 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
                                                &frame_map_, &num_frames));
   // Shapes are only needed in aggressive mode.
   graph_properties_.reset(new GraphProperties(item));
-  const Status status = graph_properties_->InferStatically(false);
-  can_use_shapes_ = status.ok();
-  if (!can_use_shapes_) {
-    LOG(WARNING) << "Shape inference failed.";
-  }
+  TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false));
+  // TODO(ezhulenev): Use GraphProperties to lookup tensor shapes directly
+  TF_RETURN_IF_ERROR(graph_properties_->AnnotateOutputShapes(optimized_graph_));
 
   // Perform the optimizations.
   DedupComputations();
   TF_RETURN_IF_ERROR(SimplifyArithmeticOps());
 
+  // Clear output shapes.
+  for (int i = 0; i < optimized_graph->node_size(); ++i) {
+    optimized_graph_->mutable_node(i)->mutable_attr()->erase(kOutputShapesAttr);
+  }
+
   return Status::OK();
 }
 
index cdeed05..965f0e9 100644 (file)
@@ -126,7 +126,6 @@ class ArithmeticOptimizer : public GraphOptimizer {
   RewriterConfig::Toggle opt_level_;
   ArithmeticOptimizerOptions options_;
 
-  bool can_use_shapes_ = false;
   bool fetch_nodes_known_ = false;
   std::unordered_set<string> nodes_to_preserve_;
   std::unique_ptr<NodeMap> node_map_;
index 1ea57f7..7044705 100644 (file)
@@ -42,10 +42,6 @@ Status GetInputNode(const GraphOptimizerContext& ctx, const string& input,
 Status GetTensorProperties(const GraphOptimizerContext& ctx,
                            const string& tensor,
                            OpInfo::TensorProperties* properties) {
-  if (ctx.graph_properties == nullptr) {
-    return errors::InvalidArgument("Graph properties are unknown.");
-  }
-
   int port;
   string tensor_node_name = ParseNodeName(tensor, &port);
   if (port < 0) {
index c7af82a..be95c00 100644 (file)
@@ -117,9 +117,6 @@ class GraphOptimizerStage {
       : optimizer_name_(optimizer_name), stage_name_(stage_name), ctx_(ctx) {}
   virtual ~GraphOptimizerStage() = default;
 
-  const string& stage_name() const { return stage_name_; }
-  const string& optimizer_name() const { return optimizer_name_; }
-
   // Check if we should try to simplify node. Returning true doesn't
   // guarantee that node will be simplified.
   //