Avoid creating large constants since protocol buffers are limited to 2GB in size.
author    Benoit Steiner <bsteiner@google.com>
Thu, 22 Feb 2018 06:02:54 +0000 (22:02 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Feb 2018 06:06:54 +0000 (22:06 -0800)
PiperOrigin-RevId: 186567461

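For illustration only (not part of the original change): below is a minimal, self-contained C++ sketch of the pattern this commit introduces, a status-returning constant factory that fills a caller-provided node and refuses to fold tensors whose encoded size would exceed a fixed cap (10MB here, mirroring the check added in the diff; protocol buffers themselves are hard-limited to 2GB). SimpleStatus, ConstNode, MakeConstNode and kMaxEncodedBytes are illustrative stand-ins, not TensorFlow APIs. In the actual diff, callers either propagate the failure with TF_RETURN_IF_ERROR (constant_folding.cc) or log a warning and skip the simplification (arithmetic_optimizer.cc).

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for a status type: ok() plus an error message.
struct SimpleStatus {
  bool ok_ = true;
  std::string message;
  bool ok() const { return ok_; }
  static SimpleStatus OK() { return {}; }
  static SimpleStatus InvalidArgument(std::string msg) {
    return {false, std::move(msg)};
  }
};

// Illustrative stand-in for a folded constant node.
struct ConstNode {
  std::string name;
  std::vector<float> values;  // packed payload, analogous to tensor_content
};

// Cap mirroring the 10 * 1024 * 1024 byte threshold in the diff.
constexpr size_t kMaxEncodedBytes = 10 * 1024 * 1024;

// Pattern from the commit: fill a caller-provided node and report failure via
// a status instead of returning the node by value.
SimpleStatus MakeConstNode(const std::string& name,
                           const std::vector<float>& values, ConstNode* node) {
  const size_t encoded_size = values.size() * sizeof(float);
  if (encoded_size >= kMaxEncodedBytes) {
    return SimpleStatus::InvalidArgument("Can't fold " + name +
                                         ", its size would be too large");
  }
  node->name = name;
  node->values = values;
  return SimpleStatus::OK();
}

int main() {
  ConstNode node;
  // Small constant: folding succeeds.
  SimpleStatus s =
      MakeConstNode("small_const", std::vector<float>(16, 1.0f), &node);
  std::cout << "small_const ok: " << s.ok() << "\n";
  // Roughly 12MB of floats: folding is rejected and the caller skips it.
  s = MakeConstNode("big_const", std::vector<float>(3 * 1024 * 1024, 1.0f),
                    &node);
  std::cout << "big_const ok: " << s.ok() << " (" << s.message << ")\n";
  return 0;
}
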
tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h

diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index c455f28..fbb3e5a 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -870,8 +870,13 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
       }
       TensorValue value(&t);
       NodeDef* new_const_node = AddNode(*node, "const", /*copy_node=*/false);
-      *new_const_node =
-          ConstantFolding::CreateNodeDef(new_const_node->name(), value);
+      status = ConstantFolding::CreateNodeDef(new_const_node->name(), value,
+                                              new_const_node);
+      if (!status.ok()) {
+        LOG(WARNING) << "Failed to create const node: "
+                     << status.error_message();
+        return "";
+      }
       new_const_node->set_device(node->device());
       nodes_to_simplify->PushBack(new_const_node);
 
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 95eaa31..064cb8b 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -529,7 +529,8 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
     out[j] = node_map_->GetNode(const_name);
     if (out[j] == nullptr) {
       out[j] = graph_->add_node();
-      *out[j] = CreateNodeDef(const_name, TensorValue(&value));
+      TF_RETURN_IF_ERROR(
+          CreateNodeDef(const_name, TensorValue(&value), out[j]));
       out[j]->set_device(node.device());
       node_map_->AddNode(const_name, out[j]);
       string ctrl_dep =
@@ -637,7 +638,8 @@ Status ConstantFolding::MaterializeReductionIndices(
       value.vec<int64>()(i) = i;
     }
   }
-  *reduction_indices = CreateNodeDef(const_name, TensorValue(&value));
+  TF_RETURN_IF_ERROR(
+      CreateNodeDef(const_name, TensorValue(&value), reduction_indices));
   reduction_indices->set_device(node->device());
   string ctrl_dep =
       AddControlDependency(node->input(1), graph_, node_map_.get());
@@ -792,19 +794,20 @@ Status CreateConstantTensorAttrValue(DataType type, double value,
 }  // namespace
 
 // static
-NodeDef ConstantFolding::CreateNodeDef(const string& name,
-                                       const TensorValue& tensor) {
-  NodeDef node;
-  node.set_name(name);
-  node.set_op("Const");
+Status ConstantFolding::CreateNodeDef(const string& name,
+                                      const TensorValue& tensor,
+                                      NodeDef* node) {
+  node->set_name(name);
+  node->set_op("Const");
 
   AttrValue attr_type;
   attr_type.set_type(tensor->dtype());
-  node.mutable_attr()->insert({"dtype", attr_type});
+  node->mutable_attr()->insert({"dtype", attr_type});
 
   AttrValue attr_tensor;
   TensorProto* t = attr_tensor.mutable_tensor();
   bool optimized = false;
+  size_t encoded_size;
   // Use the packed representation whenever possible to avoid generating large
   // graphdefs. Moreover, avoid repeating the last values if they're equal.
   if (tensor->NumElements() > 4) {
@@ -821,6 +824,7 @@ NodeDef ConstantFolding::CreateNodeDef(const string& name,
   }                                                                 \
   if (last_index < kint32max) {                                     \
     optimized = true;                                               \
+    encoded_size = (last_index + 1) * sizeof(NAME);                 \
     t->mutable_##NAME##_val()->Reserve(last_index + 1);             \
     t->mutable_##NAME##_val()->AddNAlreadyReserved(last_index + 1); \
     val_ptr = tensor->flat<TYPE>().data();                          \
@@ -853,9 +857,15 @@ NodeDef ConstantFolding::CreateNodeDef(const string& name,
     tensor->shape().AsProto(t->mutable_tensor_shape());
   } else {
     tensor->AsProtoTensorContent(t);
+    encoded_size = t->tensor_content().size();
+  }
+  node->mutable_attr()->insert({"value", attr_tensor});
+
+  if (encoded_size < 10 * 1024 * 1024) {
+    return Status::OK();
   }
-  node.mutable_attr()->insert({"value", attr_tensor});
-  return node;
+  return errors::InvalidArgument(
+      strings::StrCat("Can't fold ", name, ", its size would be too large"));
 }
 
 Status ConstantFolding::EvaluateNode(const NodeDef& node,
@@ -929,17 +939,19 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
     return Status(error::INVALID_ARGUMENT, "Expected at least one output.");
   }
 
+  outputs->resize(output_tensors.size());
   for (size_t i = 0; i < output_tensors.size(); i++) {
     string node_name = OptimizedNodeName(node, "-folded");
     if (output_tensors.size() > 1) {
       node_name = strings::StrCat(node_name, "-", i);
     }
     if (output_tensors[i].tensor) {
-      outputs->push_back(CreateNodeDef(node_name, output_tensors[i]));
+      TF_RETURN_IF_ERROR(
+          CreateNodeDef(node_name, output_tensors[i], &outputs->at(i)));
     } else {
       // Create an empty NodeDef to identify dead outputs (e.g. the output of a
       // switch that's not selected by the switch predicate).
-      outputs->push_back(NodeDef());
+      outputs->at(i) = NodeDef();
     }
   }
   return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index e407851..232b2f9 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -33,7 +33,8 @@ const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl";
 // Constant folding optimization for a graph.
 class ConstantFolding : public GraphOptimizer {
  public:
-  static NodeDef CreateNodeDef(const string& name, const TensorValue& tensor);
+  static Status CreateNodeDef(const string& name, const TensorValue& tensor,
+                              NodeDef* node);
   static string AddControlDependency(const string& input_name, GraphDef* graph,
                                      NodeMap* node_map);