From fc09f65a5d283baa9af182536e3e3652c7a41dd7 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Sun, 4 Feb 2018 11:24:51 -0800 Subject: [PATCH] Avoid retaining two copies of each constant in `ConstantOp`. Presently, the kernel keeps two copies of the constant tensor value, which can be large: 1. In the `ConstantOp::tensor_` field. 2. In the `OpKernel::def_` field (as an attr of the `NodeDef`). Since we can be sure that `ConstantOp` will never need to access the tensor value from `OpKernel::def_`, this change introduces a mechanism for `OpKernel` implementations to store a stripped `NodeDef` in the base class, and uses it in `ConstantOp` to avoid storing the tensor value attr. PiperOrigin-RevId: 184455793 --- tensorflow/core/framework/op_kernel.cc | 8 +++++++- tensorflow/core/framework/op_kernel.h | 8 ++++++++ tensorflow/core/kernels/constant_op.cc | 28 +++++++++++++++++++++++++++- 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index fd2d06b..56c013d 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -79,8 +79,14 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs, // OpKernel ------------------------------------------------------------------ +// TODO(mrry): Convert to std::make_unique when available. OpKernel::OpKernel(OpKernelConstruction* context) - : def_(new NodeDef(context->def())), + : OpKernel(context, + std::unique_ptr(new NodeDef(context->def()))) {} + +OpKernel::OpKernel(OpKernelConstruction* context, + std::unique_ptr node_def) + : def_(std::move(node_def)), input_types_(context->input_types().begin(), context->input_types().end()), input_memory_types_(context->input_memory_types().begin(), diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index b72f140..a3dc96b 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -75,6 +75,14 @@ class OpKernel { // OpKernel won't be instantiated by the scheduler, so you may perform // expensive initialization in the descendant's constructor. explicit OpKernel(OpKernelConstruction* context); + + // Specialized constructor that enables the descendant to provide a different + // `NodeDef` value. For example, this constructor can be used to provide a + // stripped-down `NodeDef` that does not contain the full set of attrs (such + // as tensor values) if the descendant stores them in a different form. + explicit OpKernel(OpKernelConstruction* context, + std::unique_ptr node_def); + virtual ~OpKernel(); // An OpKernel's computation can be either synchronous or diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index 920cd87..4ab6fdb 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/kernels/constant_op.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -41,8 +42,33 @@ limitations under the License. namespace tensorflow { +namespace { + +std::unique_ptr StripTensorDataFromNodeDef( + OpKernelConstruction* ctx) { +#ifndef __ANDROID__ + DCHECK_EQ(NodeDef::descriptor()->field_count(), 5) + << "The NodeDef format has changed, and the attr-stripping code may need " + << "to be updated."; +#endif + const NodeDef& original = ctx->def(); + NodeDef* ret = new NodeDef; + ret->set_name(original.name()); + ret->set_op(original.op()); + ret->set_device(original.device()); + // Strip the "value" attr from the returned NodeDef. + // NOTE(mrry): The present implementation of `OpKernel::OpKernel()` only uses + // attrs that affect the cardinality of list-typed inputs and outputs, so it + // is safe to drop other attrs from the NodeDef. + AddNodeAttr("dtype", ctx->output_type(0), ret); + return std::unique_ptr(ret); +} + +} // namespace + ConstantOp::ConstantOp(OpKernelConstruction* ctx) - : OpKernel(ctx), tensor_(ctx->output_type(0)) { + : OpKernel(ctx, StripTensorDataFromNodeDef(ctx)), + tensor_(ctx->output_type(0)) { const TensorProto* proto = nullptr; OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto)); OP_REQUIRES_OK(ctx, ctx->device()->MakeTensorFromProto( -- 2.7.4