Avoid retaining two copies of each constant in `ConstantOp`.
authorDerek Murray <mrry@google.com>
Sun, 4 Feb 2018 19:24:51 +0000 (11:24 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sun, 4 Feb 2018 19:28:47 +0000 (11:28 -0800)
Presently, the kernel keeps two copies of the constant tensor value, which can be large:
1. In the `ConstantOp::tensor_` field.
2. In the `OpKernel::def_` field (as an attr of the `NodeDef`).

Since we can be sure that `ConstantOp` will never need to access the
tensor value from `OpKernel::def_`, this change introduces a mechanism
for `OpKernel` implementations to store a stripped `NodeDef` in the
base class, and uses it in `ConstantOp` to avoid storing the tensor
value attr.

PiperOrigin-RevId: 184455793

tensorflow/core/framework/op_kernel.cc
tensorflow/core/framework/op_kernel.h
tensorflow/core/kernels/constant_op.cc

index fd2d06b..56c013d 100644 (file)
@@ -79,8 +79,14 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
 
 // OpKernel ------------------------------------------------------------------
 
+// TODO(mrry): Convert to std::make_unique when available.
 OpKernel::OpKernel(OpKernelConstruction* context)
-    : def_(new NodeDef(context->def())),
+    : OpKernel(context,
+               std::unique_ptr<const NodeDef>(new NodeDef(context->def()))) {}
+
+OpKernel::OpKernel(OpKernelConstruction* context,
+                   std::unique_ptr<const NodeDef> node_def)
+    : def_(std::move(node_def)),
       input_types_(context->input_types().begin(),
                    context->input_types().end()),
       input_memory_types_(context->input_memory_types().begin(),
index b72f140..a3dc96b 100644 (file)
@@ -75,6 +75,14 @@ class OpKernel {
   // OpKernel won't be instantiated by the scheduler, so you may perform
   // expensive initialization in the descendant's constructor.
   explicit OpKernel(OpKernelConstruction* context);
+
+  // Specialized constructor that enables the descendant to provide a different
+  // `NodeDef` value. For example, this constructor can be used to provide a
+  // stripped-down `NodeDef` that does not contain the full set of attrs (such
+  // as tensor values) if the descendant stores them in a different form.
+  explicit OpKernel(OpKernelConstruction* context,
+                    std::unique_ptr<const NodeDef> node_def);
+
   virtual ~OpKernel();
 
   // An OpKernel's computation can be either synchronous or
index 920cd87..4ab6fdb 100644 (file)
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/constant_op.h"
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor.pb.h"
@@ -41,8 +42,33 @@ limitations under the License.
 
 namespace tensorflow {
 
+namespace {
+
+std::unique_ptr<const NodeDef> StripTensorDataFromNodeDef(
+    OpKernelConstruction* ctx) {
+#ifndef __ANDROID__
+  DCHECK_EQ(NodeDef::descriptor()->field_count(), 5)
+      << "The NodeDef format has changed, and the attr-stripping code may need "
+      << "to be updated.";
+#endif
+  const NodeDef& original = ctx->def();
+  NodeDef* ret = new NodeDef;
+  ret->set_name(original.name());
+  ret->set_op(original.op());
+  ret->set_device(original.device());
+  // Strip the "value" attr from the returned NodeDef.
+  // NOTE(mrry): The present implementation of `OpKernel::OpKernel()` only uses
+  // attrs that affect the cardinality of list-typed inputs and outputs, so it
+  // is safe to drop other attrs from the NodeDef.
+  AddNodeAttr("dtype", ctx->output_type(0), ret);
+  return std::unique_ptr<const NodeDef>(ret);
+}
+
+}  // namespace
+
 ConstantOp::ConstantOp(OpKernelConstruction* ctx)
-    : OpKernel(ctx), tensor_(ctx->output_type(0)) {
+    : OpKernel(ctx, StripTensorDataFromNodeDef(ctx)),
+      tensor_(ctx->output_type(0)) {
   const TensorProto* proto = nullptr;
   OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto));
   OP_REQUIRES_OK(ctx, ctx->device()->MakeTensorFromProto(