// OpKernel ------------------------------------------------------------------
+// TODO(mrry): Convert to std::make_unique when available.
OpKernel::OpKernel(OpKernelConstruction* context)
- : def_(new NodeDef(context->def())),
+ : OpKernel(context,
+ std::unique_ptr<const NodeDef>(new NodeDef(context->def()))) {}
+
+OpKernel::OpKernel(OpKernelConstruction* context,
+ std::unique_ptr<const NodeDef> node_def)
+ : def_(std::move(node_def)),
input_types_(context->input_types().begin(),
context->input_types().end()),
input_memory_types_(context->input_memory_types().begin(),
// OpKernel won't be instantiated by the scheduler, so you may perform
// expensive initialization in the descendant's constructor.
explicit OpKernel(OpKernelConstruction* context);
+
+ // Specialized constructor that enables the descendant to provide a different
+ // `NodeDef` value. For example, this constructor can be used to provide a
+ // stripped-down `NodeDef` that does not contain the full set of attrs (such
+ // as tensor values) if the descendant stores them in a different form.
+ explicit OpKernel(OpKernelConstruction* context,
+ std::unique_ptr<const NodeDef> node_def);
+
virtual ~OpKernel();
// An OpKernel's computation can be either synchronous or
#include "tensorflow/core/kernels/constant_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
namespace tensorflow {
+namespace {
+
+std::unique_ptr<const NodeDef> StripTensorDataFromNodeDef(
+ OpKernelConstruction* ctx) {
+#ifndef __ANDROID__
+ DCHECK_EQ(NodeDef::descriptor()->field_count(), 5)
+ << "The NodeDef format has changed, and the attr-stripping code may need "
+ << "to be updated.";
+#endif
+ const NodeDef& original = ctx->def();
+ NodeDef* ret = new NodeDef;
+ ret->set_name(original.name());
+ ret->set_op(original.op());
+ ret->set_device(original.device());
+ // Strip the "value" attr from the returned NodeDef.
+ // NOTE(mrry): The present implementation of `OpKernel::OpKernel()` only uses
+ // attrs that affect the cardinality of list-typed inputs and outputs, so it
+ // is safe to drop other attrs from the NodeDef.
+ AddNodeAttr("dtype", ctx->output_type(0), ret);
+ return std::unique_ptr<const NodeDef>(ret);
+}
+
+} // namespace
+
ConstantOp::ConstantOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), tensor_(ctx->output_type(0)) {
+ : OpKernel(ctx, StripTensorDataFromNodeDef(ctx)),
+ tensor_(ctx->output_type(0)) {
const TensorProto* proto = nullptr;
OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto));
OP_REQUIRES_OK(ctx, ctx->device()->MakeTensorFromProto(