Demystify MaterializeShapes a bit.
author     Max Galkin <maxgalkin@google.com>
Tue, 13 Mar 2018 01:35:15 +0000 (18:35 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Mar 2018 01:39:38 +0000 (18:39 -0700)
PiperOrigin-RevId: 188812445

tensorflow/core/grappler/optimizers/constant_folding.cc

index 4c9431d..a4d8376 100644
@@ -244,44 +244,41 @@ string ConstantFolding::AddControlDependency(const string& input_name,
   }
 }
 
-Status ConvertShapeToConstant(const string& op, const DataType& type,
-                              const PartialTensorShape& shp, Tensor* value) {
+// Puts the given value into the tensor at the given "flat" index.
+static Status PutValueIntoTensor(const int64 value, const DataType& type,
+                                 const int index, Tensor* tensor) {
+  if (type == DT_INT32) {
+    if (value >= INT_MAX) {
+      return Status(error::INVALID_ARGUMENT, "int32 overflow");
+    }
+    tensor->flat<int32>()(index) = static_cast<int32>(value);
+  } else {
+    tensor->flat<int64>()(index) = value;
+  }
+  return Status::OK();
+}
+
+// Writes the given tensor shape into the given tensor.
+// Op is assumed to be Shape, ShapeN, Size or Rank.
+static Status ConvertShapeToConstant(const string& op, const DataType& type,
+                                     const PartialTensorShape& shp,
+                                     Tensor* tensor) {
   if (op == "Shape" || op == "ShapeN") {
-    *value = Tensor(type, TensorShape({shp.dims()}));
+    *tensor = Tensor(type, TensorShape({shp.dims()}));
     for (int i = 0; i < shp.dims(); ++i) {
-      if (type == DT_INT32) {
-        if (shp.dim_size(i) >= INT_MAX) {
-          return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
-        }
-        value->flat<int32>()(i) = shp.dim_size(i);
-      } else {
-        value->flat<int64>()(i) = shp.dim_size(i);
-      }
+      TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dim_size(i), type, i, tensor));
     }
   } else if (op == "Size") {
     int64 size = 1;
     for (int i = 0; i < shp.dims(); ++i) {
       size *= shp.dim_size(i);
     }
-    *value = Tensor(type, TensorShape({}));
-    if (type == DT_INT32) {
-      if (size >= INT_MAX) {
-        return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
-      }
-      value->flat<int32>()(0) = size;
-    } else {
-      value->flat<int64>()(0) = size;
-    }
+    *tensor = Tensor(type, TensorShape({}));
+    TF_RETURN_IF_ERROR(PutValueIntoTensor(size, type, 0, tensor));
   } else {
-    *value = Tensor(type, TensorShape({}));
-    if (type == DT_INT32) {
-      if (shp.dims() >= INT_MAX) {
-        return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
-      }
-      value->flat<int32>()(0) = shp.dims();
-    } else {
-      value->flat<int64>()(0) = shp.dims();
-    }
+    CHECK_EQ(op, "Rank");
+    *tensor = Tensor(type, TensorShape({}));
+    TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dims(), type, 0, tensor));
   }
   return Status::OK();
 }
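
As a rough standalone illustration of what the new helpers compute, here is a minimal sketch in plain C++: std::vector<int64_t> stands in for Tensor and PartialTensorShape, and the names PutValue and ShapeToConstant are invented for the sketch; it mirrors the logic above (including the int32 narrowing guard) but is not the TensorFlow API.

#include <climits>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Mirrors the int32 guard in PutValueIntoTensor: refuse to materialize a
// value that would not fit into an int32 output tensor.
bool PutValue(int64_t value, bool is_int32, std::vector<int64_t>* out) {
  if (is_int32 && value >= INT_MAX) return false;  // "int32 overflow"
  out->push_back(value);
  return true;
}

// Mirrors ConvertShapeToConstant for a fully defined shape:
//   Shape/ShapeN -> the list of dimension sizes
//   Size         -> the product of the dimension sizes (a scalar)
//   Rank         -> the number of dimensions (a scalar)
bool ShapeToConstant(const std::string& op, const std::vector<int64_t>& dims,
                     bool is_int32, std::vector<int64_t>* out) {
  if (op == "Shape" || op == "ShapeN") {
    for (int64_t d : dims) {
      if (!PutValue(d, is_int32, out)) return false;
    }
    return true;
  }
  if (op == "Size") {
    int64_t size = 1;
    for (int64_t d : dims) size *= d;
    return PutValue(size, is_int32, out);
  }
  // op == "Rank"
  return PutValue(static_cast<int64_t>(dims.size()), is_int32, out);
}

int main() {
  const std::vector<int64_t> dims = {2, 3, 5};
  for (const std::string op : {"Shape", "Size", "Rank"}) {
    std::vector<int64_t> value;
    ShapeToConstant(op, dims, /*is_int32=*/true, &value);
    std::cout << op << ":";
    for (int64_t v : value) std::cout << " " << v;  // Shape: 2 3 5, Size: 30, Rank: 3
    std::cout << "\n";
  }
}
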
@@ -306,13 +303,14 @@ bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
   return feed_nodes_.find(node.name()) == feed_nodes_.end();
 }
 
+// Materialize the shapes using constants whenever possible.
 Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
-  // We may add some nodes to the graph to encode control dependencies: there is
-  // no need to process these, so only iterate over the nodes of the input
-  // graph.
+  // We may add some nodes to the graph to encode control dependencies and hold
+  // the materialized shapes: there is no need to process these added nodes, so
+  // only iterate over the nodes of the input graph.
   const int node_count = graph_->node_size();
-  for (int i = 0; i < node_count; ++i) {
-    NodeDef* node = graph_->mutable_node(i);
+  for (int node_idx = 0; node_idx < node_count; ++node_idx) {
+    NodeDef* node = graph_->mutable_node(node_idx);
     const string op = node->op();
     if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") {
       continue;
@@ -325,91 +323,109 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
     if (input.empty() || output.empty()) {
       continue;
     }
+
     if (op == "Shape" || op == "Size" || op == "Rank") {
       CHECK_EQ(1, output.size());
       CHECK_EQ(1, input.size());
+
+      const DataType type = output[0].dtype();
+      CHECK(type == DT_INT32 || type == DT_INT64);
+      const PartialTensorShape shape(input[0].shape());
+
+      if ((op != "Rank" && !shape.IsFullyDefined()) ||
+          (op == "Rank" && shape.unknown_rank())) {
+        continue;
+      }
+
+      Tensor constant_value(type);
+      if (!ConvertShapeToConstant(op, type, shape, &constant_value).ok()) {
+        continue;
+      }
+
+      // Repurpose the existing node to be the constant.
+      // Device placement is preserved.
+      node->set_op("Const");
+      node->clear_attr();
+      (*node->mutable_attr())["dtype"].set_type(type);
+      constant_value.AsProtoTensorContent(
+          (*node->mutable_attr())["value"].mutable_tensor());
+
+      // Turn the data input into a control dependency: this is needed to
+      // ensure that the constant value will only be run in the
+      // cases where the shape/rank/size would have been run in
+      // the original graph.
+      string ctrl_dep =
+          AddControlDependency(node->input(0), graph_, node_map_.get());
+      node->set_input(0, ctrl_dep);
+      node_map_->AddOutput(NodeName(ctrl_dep), node->name());
+
+      // Done with the Shape/Size/Rank node, move to the next node.
+      continue;
     }
-    CHECK_EQ(input.size(), output.size());
 
-    for (int j = 0; j < output.size(); ++j) {
-      const DataType type = output[j].dtype();
+    // Handle ShapeN materialization case.
+    // It's possible that not all input tensors have known shapes.
+    CHECK_EQ(op, "ShapeN");
+    CHECK_EQ(input.size(), output.size());
+    const NodeDef* const shape_n_node = node;
+    for (int port_idx = 0; port_idx < output.size(); ++port_idx) {
+      const DataType type = output[port_idx].dtype();
       CHECK(type == DT_INT32 || type == DT_INT64);
-      const TensorShapeProto shape = input[j].shape();
-      // Materialize the shapes using constants whenever possible.
-      PartialTensorShape shp(shape);
-      if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
-        Tensor value(type);
-        auto status = ConvertShapeToConstant(op, type, shp, &value);
-        if (!status.ok()) {
-          continue;
-        }
-        // We rewrite the existing node for the first const output and
-        // create new nodes for the remaining const outputs (Note that ShapeN
-        // could have multiple outputs).
-        if (op == "Shape" || op == "Size" || op == "Rank") {
-          // Replace the node with the corresponding constant.
-          node->set_op("Const");
-          node->clear_attr();
-          (*node->mutable_attr())["dtype"].set_type(type);
-          value.AsProtoTensorContent(
-              (*node->mutable_attr())["value"].mutable_tensor());
-
-          // Turn the data input into a control dependency: this is needed to
-          // ensure that the constant value will only be run in the
-          // cases where the shape/rank/size would have been run in
-          // the original graph. Additional inputs are extra control
-          string ctrl_dep =
-              AddControlDependency(node->input(0), graph_, node_map_.get());
-          node->set_input(0, ctrl_dep);
-          node_map_->AddOutput(NodeName(ctrl_dep), node->name());
-        } else {
-          auto outputs = node_map_->GetOutputs(node->name());
-          for (NodeDef* output : outputs) {
-            for (int k = 0; k < output->input_size(); ++k) {
-              int port;
-              string node_name = ParseNodeName(output->input(k), &port);
-              if (node_name == node->name() && port == j) {
-                // Create a const node as ShapeN's output if not already.
-                const string const_name =
-                    OptimizedNodeName(*node, strings::StrCat("-matshapes-", j));
-                if (node_map_->GetNode(const_name) == nullptr) {
-                  NodeDef* added_node = graph_->add_node();
-                  added_node->set_name(const_name);
-                  added_node->set_op("Const");
-                  added_node->set_device(node->device());
-                  node_map_->AddNode(added_node->name(), added_node);
-                  (*added_node->mutable_attr())["dtype"].set_type(type);
-                  value.AsProtoTensorContent(
-                      (*added_node->mutable_attr())["value"].mutable_tensor());
-                  // We add a control dependency to the original ShapeN node,
-                  // so that the node will only be run if all inputs of the
-                  // original ShapeN node are run.
-                  string ctrl_dep = AddControlDependency(node->name(), graph_,
-                                                         node_map_.get());
-                  *added_node->add_input() = ctrl_dep;
-                  node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
-                }
-                *output->mutable_input(k) = const_name;
-                node_map_->AddOutput(const_name, output->name());
-              }
-            }
-            bool remove_output = true;
-            for (int k = 0; k < output->input_size(); ++k) {
-              int port;
-              string node_name = ParseNodeName(output->input(k), &port);
-              if (node_name == node->name()) {
-                remove_output = false;
-                break;
-              }
-            }
-            if (remove_output) {
-              node_map_->RemoveOutput(node->name(), output->name());
+      const PartialTensorShape shape(input[port_idx].shape());
+      if (!shape.IsFullyDefined()) {
+        continue;
+      }
+      Tensor constant_value(type);
+      auto status = ConvertShapeToConstant(op, type, shape, &constant_value);
+      if (!status.ok()) {
+        continue;
+      }
+
+      // Find all nodes consuming this shape and connect them through the new
+      // constant node instead.
+      auto outputs = node_map_->GetOutputs(shape_n_node->name());
+      for (NodeDef* output : outputs) {
+        // Track whether there are any direct edges left between shape_n_node
+        // and this output node after the transformation.
+        bool direct_edges_exist = false;
+        for (int k = 0; k < output->input_size(); ++k) {
+          int port;
+          const string node_name = ParseNodeName(output->input(k), &port);
+          if (node_name == shape_n_node->name() && port == port_idx) {
+            // Create a const node as ShapeN's output if not already.
+            const string const_name = OptimizedNodeName(
+                *shape_n_node, strings::StrCat("-matshapes-", port_idx));
+            if (node_map_->GetNode(const_name) == nullptr) {
+              NodeDef* added_node = graph_->add_node();
+              added_node->set_name(const_name);
+              added_node->set_op("Const");
+              added_node->set_device(shape_n_node->device());
+              node_map_->AddNode(added_node->name(), added_node);
+              (*added_node->mutable_attr())["dtype"].set_type(type);
+              constant_value.AsProtoTensorContent(
+                  (*added_node->mutable_attr())["value"].mutable_tensor());
+              // We add a control dependency to the original ShapeN node,
+              // so that the node will only be run if all inputs of the
+              // original ShapeN node are run.
+              string ctrl_dep = AddControlDependency(shape_n_node->name(),
+                                                     graph_, node_map_.get());
+              *added_node->add_input() = ctrl_dep;
+              node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
             }
+            *output->mutable_input(k) = const_name;
+            node_map_->AddOutput(const_name, output->name());
           }
+          if (node_name == shape_n_node->name() && port != port_idx) {
+            direct_edges_exist = true;
+          }
+        }
+        if (!direct_edges_exist) {
+          node_map_->RemoveOutput(node->name(), output->name());
         }
       }
     }
   }
+
   return Status::OK();
 }
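
The net effect on a graph, sketched informally (the node names below are invented for illustration; the "-matshapes-" suffix comes from the code above): a Shape, Size, or Rank node whose input shape is fully defined (or, for Rank, whose rank is known) is rewritten in place into a Const holding the materialized value, with its data input demoted to a control dependency; a ShapeN node stays in the graph, and each consumer of a fully defined output port is rewired to a new per-port Const ("<shape_n_name>-matshapes-<port>") that carries a control dependency on the original ShapeN node.

Before:
  x = Placeholder(float, shape=[2, 3])
  s = Shape(x)
  y = Reshape(z, s)

After:
  x = Placeholder(float, shape=[2, 3])
  s = Const(value=[2, 3], input=^x)   // same node name and device; data input became a control dep
  y = Reshape(z, s)
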