From 245f6c840c86f79a0e42d81ca84b8ce5eeb4c396 Mon Sep 17 00:00:00 2001
From: Derek Murray
Date: Fri, 2 Feb 2018 17:09:59 -0800
Subject: [PATCH] [tf.data] Add public header "tensorflow/core/framework/dataset.h".

This adds the ability to create a custom C++ Dataset implementation without
linking it statically into the TensorFlow library. Note that this internal
API is experimental and subject to change between versions of TensorFlow.

Fixes #16682.

PiperOrigin-RevId: 184356318
---
 tensorflow/core/BUILD                  |   1 +
 tensorflow/core/framework/dataset.h    | 614 ++++++++++++++++++++++++++++++++
 tensorflow/core/kernels/data/dataset.h | 596 +------------------------------
 3 files changed, 616 insertions(+), 595 deletions(-)
 create mode 100644 tensorflow/core/framework/dataset.h

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 68c82ea..3aa3018 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -432,6 +432,7 @@ tf_cuda_library(
         "framework/cancellation.h",
         "framework/common_shape_fns.h",
         "framework/control_flow.h",  # TODO(josh11b): Make internal?
+        "framework/dataset.h",
         "framework/dataset_stateful_op_whitelist.h",
         "framework/device_base.h",
         "framework/function.h",
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
new file mode 100644
index 0000000..96566c2
--- /dev/null
+++ b/tensorflow/core/framework/dataset.h
@@ -0,0 +1,614 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
+#define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
+
+#include <memory>
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/framework/variant_tensor_data.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/tracing.h"
+
+// Polymorphic datasets should support all primitive TensorFlow
+// types. Use this macro to expand `m(T)` once for each primitive type
+// `T`, e.g. to build a `switch` statement.
+#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
+
+namespace tensorflow {
+
+// Interface for reading values from a key-value store.
+// Used for restoring iterator state.
+class IteratorStateReader {
+ public:
+  virtual Status ReadScalar(StringPiece key, int64* val) = 0;
+  virtual Status ReadScalar(StringPiece key, string* val) = 0;
+  virtual Status ReadTensor(StringPiece key, Tensor* val) = 0;
+  virtual bool Contains(StringPiece key) = 0;
+
+  virtual ~IteratorStateReader() {}
+};
+
+// Interface for writing values to a key-value store.
+// Used for saving iterator state.
+class IteratorStateWriter {
+ public:
+  virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
+  virtual Status WriteScalar(StringPiece key, const string& val) = 0;
+  virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;
+
+  virtual ~IteratorStateWriter() {}
+};
+
+// Forward declarations to avoid introducing a dependency on headers in
+// "tensorflow/core/graph/...".
+class GraphDefBuilder;
+class GraphDatasetBase;
+class Node;
+
+// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
+class GraphDefBuilderWrapper {
+ public:
+  explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {}
+
+  // Adds a Const node with scalar value to the Graph.
+  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
+  // non-null if the method returns with an OK status.
+  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
+  template <typename T>
+  Status AddScalar(const T& val, Node** output) {
+    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
+    val_t.scalar<T>()() = val;
+    AddTensorInternal(val_t, output);
+    if (*output == nullptr) {
+      return errors::Internal("AddScalar: Failed to build Const op.");
+    }
+    return Status::OK();
+  }
+
+  // Adds a Const node with vector value to the Graph.
+  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
+  // non-null if the method returns with an OK status.
+  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
+  // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice?
+  template <typename T>
+  Status AddVector(const std::vector<T>& val, Node** output) {
+    Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
+                          TensorShape({static_cast<int64>(val.size())}));
+    for (int i = 0; i < val.size(); i++) {
+      val_t.flat<T>()(i) = val[i];
+    }
+    AddTensorInternal(val_t, output);
+    if (*output == nullptr) {
+      return errors::Internal("AddVector: Failed to build Const op.");
+    }
+    return Status::OK();
+  }
+
+  // Adds a Const node with Tensor value to the Graph.
+  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
+  // non-null if the method returns with an OK status.
+  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
+  Status AddTensor(const Tensor& val, Node** output) {
+    AddTensorInternal(val, output);
+    if (*output == nullptr) {
+      return errors::Internal("AddTensor: Failed to build Const op.");
+    }
+    return Status::OK();
+  }
+
+  Status AddDataset(const GraphDatasetBase* dataset,
+                    const std::vector<Node*>& inputs, Node** output) {
+    return AddDataset(dataset, inputs, {}, output);
+  }
+
+  // Adds a node corresponding to the `DatasetType` to the Graph.
+  // Return value of `DatasetType::op_name()` is used as the op type for the
+  // node.
+  // Values for the output_types and output_shapes node attributes are also
+  // written if those attributes are defined in the OpDef.
+  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
+  // non-null if the method returns with an OK status.
+  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
+  Status AddDataset(const GraphDatasetBase* dataset,
+                    const std::vector<Node*>& inputs,
+                    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
+                    Node** output) {
+    std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
+    for (int i = 0; i < inputs.size(); i++) {
+      enumerated_inputs[i] = std::make_pair(i, inputs[i]);
+    }
+    return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
+  }
+
+  Status AddDataset(
+      const GraphDatasetBase* dataset,
+      const std::vector<std::pair<size_t, Node*>>& inputs,
+      const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
+      const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
+      Node** output);
+
+  // Adds a user-defined function with name `function_name` to the graph and
+  // recursively adds all functions it references. If a function with a
+  // matching name has already been added, returns with OK status. If a
+  // user-defined function with name `function_name` is not found in the
+  // FunctionLibraryDefinition, returns an InvalidArgumentError. If the
+  // function with name `function_name` or any of its dependent functions
+  // are stateful, returns an InvalidArgument error.
+  Status AddFunction(OpKernelContext* ctx, const string& function_name);
+
+  template <typename T>
+  void BuildAttrValue(const T& value, AttrValue* attr) {
+    SetAttrValue(value, attr);
+  }
+
+ private:
+  void AddTensorInternal(const Tensor& val, Node** output);
+
+  Status EnsureFunctionIsStateless(OpKernelContext* ctx,
+                                   const string& function_name) const {
+    const FunctionLibraryDefinition* lib_def =
+        ctx->function_library()->GetFunctionLibraryDefinition();
+    const FunctionDef* function_def = lib_def->Find(function_name);
+    if (!function_def) {
+      return errors::InvalidArgument("Unable to find FunctionDef for ",
+                                     function_name, " in registry.");
+    }
+    for (const NodeDef& node_def : function_def->node_def()) {
+      const OpDef* op_def;
+      TF_RETURN_IF_ERROR(lib_def->LookUpOpDef(node_def.op(), &op_def));
+      // TODO(b/65524810): Hack to allow functions to capture Dataset op
+      // nodes needed for FlatMap. Currently, source dataset nodes have been
+      // marked stateful to avoid constant folding since we do not have a
+      // good way of serializing them.
+      if (IsOpWhitelisted(op_def)) {
+        continue;
+      }
+      if (op_def->is_stateful()) {
+        return errors::InvalidArgument(
+            "Op[name: ", node_def.name(), ", type: ", node_def.op(), "] ",
+            "in function ", function_name, " is stateful. ",
+            "Saving stateful functions is not supported yet.");
+      }
+    }
+    return Status::OK();
+  }
+
+  // Returns whether an op has been whitelisted for use inside map_fns.
+  // Uses a heuristic to whitelist source dataset ops which have been
+  // marked stateful due to b/65524810.
+  // Also looks up the `op_def->name` in the global
+  // `WhitelistedStatefulOpRegistry`.
+  bool IsOpWhitelisted(const OpDef* op_def) const {
+    return (StringPiece(op_def->name()).ends_with("Dataset") &&
+            op_def->output_arg_size() == 1 &&
+            op_def->output_arg(0).type() == DT_VARIANT) ||
+           dataset::WhitelistedStatefulOpRegistry::Global()->Contains(
+               op_def->name());
+  }
+
+  bool HasAttr(const string& op_type_name, const string& attr_name) const;
+
+  bool HasAttr(const OpDef* op_def, const string& attr_name) const {
+    for (auto attr : op_def->attr()) {
+      if (attr.name() == attr_name) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  Status AddAttrFunctions(const AttrValue& attr_value, OpKernelContext* ctx) {
+    if (attr_value.has_func()) {
+      TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name()));
+    } else if (attr_value.has_list()) {
+      for (const NameAttrList& name_attr_list : attr_value.list().func()) {
+        TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name()));
+      }
+    }
+    return Status::OK();
+  }
+
+  GraphDefBuilder* b_;
+};
+
+class StatsAggregator;
+
+// A cut-down version of OpKernelContext for running computations in
+// iterators. Note that we cannot simply use OpKernelContext here
+// because we might run computation in an iterator whose lifetime is
+// not nested within the lifetime of a single OpKernelContext
+// (e.g. asynchronous prefetching).
+//
+// TODO(mrry): We will probably need to support more of
+// OpKernelContext here. For example, should allocation be handled by
+// the IteratorContext?
+// TODO(mrry): We're making some daring assumptions about the lifetime
+// of the runner passed in here. A runner will be deleted when the original
+// step ends, but all existing runners only close over session-lifetime (or
+// longer-lived) state, so we can make a copy of the function. There's nothing
+// in the definition of the API from which we took the runner to guarantee that
+// what we are doing is safe. We should formalize the properties here.
+class IteratorContext {
+ public:
+  struct Params {
+    // Interface to operating system functionality.
+    Env* env;
+
+    // Function call support.
+    std::function<void(std::function<void()>)> runner = nullptr;
+
+    // A function that returns the current `StatsAggregator` instance to be
+    // used when recording statistics about the iterator.
+    //
+    // NOTE(mrry): This is somewhat awkward, because (i) the `StatsAggregator`
+    // is a property of the `IteratorResource` (which this class does not know
+    // about), and (ii) it can change after the `IteratorContext` has been
+    // created. Better suggestions are welcome!
+    std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter =
+        nullptr;
+
+    // The FunctionLibraryRuntime object to be used to make function calls.
+    FunctionLibraryRuntime* lib = nullptr;
+    std::shared_ptr<const FunctionLibraryDefinition> function_library = nullptr;
+
+    // The Allocator to be used to allocate the output of an iterator.
+    Allocator* allocator = nullptr;
+  };
+
+  explicit IteratorContext(Params params) : params_(std::move(params)) {}
+
+  Env* env() const { return params_.env; }
+
+  std::function<void(std::function<void()>)>* runner() {
+    return &params_.runner;
+  }
+
+  std::shared_ptr<StatsAggregator> stats_aggregator() {
+    if (params_.stats_aggregator_getter) {
+      return params_.stats_aggregator_getter();
+    } else {
+      return nullptr;
+    }
+  }
+
+  std::shared_ptr<const FunctionLibraryDefinition> function_library() {
+    return params_.function_library;
+  }
+
+  FunctionLibraryRuntime* lib() { return params_.lib; }
+
+  void set_lib(FunctionLibraryRuntime* lib) { params_.lib = lib; }
+
+  Allocator* allocator(AllocatorAttributes attrs);
+
+ private:
+  Params params_;
+};
+
+// Represents the current position in a range of outputs, where the
+// range of outputs is typically represented by a `DatasetBase`,
+// defined below.
+class IteratorBase {
+ public:
+  virtual ~IteratorBase() {}
+
+  // Gets the next output from the range that this iterator is traversing.
+  //
+  // If at least one output remains in this iterator's range, that
+  // output will be stored in `*out_tensors` and `false` will be
+  // stored in `*end_of_sequence`.
+  //
+  // If no more outputs remain in this iterator's range, `true` will
+  // be stored in `*end_of_sequence`, and the content of
+  // `*out_tensors` will be undefined.
+  //
+  // This method is thread-safe.
+  //
+  // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
+  // potentially remove this method.
+  virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+                         bool* end_of_sequence) = 0;
+
+  // Returns a vector of DataType values, representing the respective
+  // element types of each tuple component in the outputs of this
+  // iterator.
+  virtual const DataTypeVector& output_dtypes() const = 0;
+
+  // Returns a vector of tensor shapes, representing the respective
+  // (and possibly partially defined) shapes of each tuple component
+  // in the outputs of this iterator.
+  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
+
+  // Saves the state of this iterator.
+  virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) {
+    return SaveInternal(writer);
+  }
+
+  // Restores the state of this iterator.
+  virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
+    return RestoreInternal(ctx, reader);
+  }
+
+ protected:
+  // This is needed so that sub-classes of IteratorBase can call
+  // `SaveInternal` on their parent iterators, e.g., in
+  // `RepeatDatasetOp::Dataset`.
+  Status SaveParent(IteratorStateWriter* writer,
+                    const std::unique_ptr<IteratorBase>& parent) {
+    return parent->SaveInternal(writer);
+  }
+
+  // This is needed so that sub-classes of IteratorBase can call
+  // `RestoreInternal` on their parent iterators, e.g., in
+  // `RepeatDatasetOp::Dataset`.
+  Status RestoreParent(IteratorContext* ctx, IteratorStateReader* reader,
+                       const std::unique_ptr<IteratorBase>& parent) {
+    return parent->RestoreInternal(ctx, reader);
+  }
+
+  // Saves the state of this iterator recursively.
+  virtual Status SaveInternal(IteratorStateWriter* writer) {
+    return errors::Unimplemented("SaveInternal");
+  }
+
+  // Restores the state of this iterator recursively.
+  virtual Status RestoreInternal(IteratorContext* ctx,
+                                 IteratorStateReader* reader) {
+    return errors::Unimplemented("RestoreInternal");
+  }
+};
+
+// Represents a (potentially infinite) range of outputs, where each
+// output is a tuple of tensors.
+class DatasetBase : public core::RefCounted {
+ public:
+  // Returns a new iterator for iterating over the range of elements in
+  // this dataset.
+  //
+  // This method may be called multiple times on the same instance,
+  // and the resulting iterators will have distinct state. Each
+  // iterator will traverse all elements in this dataset from the
+  // start.
+  //
+  // Ownership of the created iterator will be transferred to the caller.
+  //
+  // The prefix identifies the sequence of iterators leading up to the newly
+  // created iterator.
+  virtual std::unique_ptr<IteratorBase> MakeIterator(
+      const string& prefix) const = 0;
+
+  // Returns a vector of DataType values, representing the respective
+  // element types of each tuple component in the outputs of this
+  // dataset.
+  virtual const DataTypeVector& output_dtypes() const = 0;
+
+  // Returns a vector of tensor shapes, representing the respective
+  // (and possibly partially defined) shapes of each tuple component
+  // in the outputs of this dataset.
+  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
+
+  // A human-readable debug string for this dataset.
+  virtual string DebugString() = 0;
+
+  // Serializes the dataset and writes it to the `writer`.
+  virtual Status Save(OpKernelContext* ctx,
+                      IteratorStateWriter* writer) const {
+    return errors::Unimplemented("DatasetBase::Save");
+  }
+
+ protected:
+  // TODO(srbs): Ideally all graph related logic should reside in
+  // GraphDatasetBase. However, that would require Datasets defined in all ops
+  // to derive from GraphDatasetBase. Once that is done we can move
+  // DatasetGraphDefBuilder and AsGraphDefInternal to GraphDatasetBase.
+  class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
+   public:
+    DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
+    Status AddParentDataset(OpKernelContext* ctx, const DatasetBase* dataset,
+                            Node** output) {
+      return dataset->AsGraphDefInternal(ctx, this, output);
+    }
+  };
+
+  virtual Status AsGraphDefInternal(OpKernelContext* ctx,
+                                    DatasetGraphDefBuilder* b,
+                                    Node** node) const {
+    return AsGraphDefInternal(b, node);
+  }
+
+  virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+                                    Node** node) const {
+    return errors::Unimplemented("AsGraphDefInternal");
+  }
+};
+
+// Base-class for datasets that are built by ops.
+class GraphDatasetBase : public DatasetBase {
+ public:
+  GraphDatasetBase(OpKernelContext* ctx)
+      : op_name_(ctx->op_kernel().type_string()) {}
+
+  const string op_name() const { return op_name_; }
+
+  Status Save(OpKernelContext* ctx,
+              IteratorStateWriter* writer) const override {
+    string serialized_graph_def;
+    string output_node;
+    TF_RETURN_IF_ERROR(Serialize(ctx, &serialized_graph_def, &output_node));
+    TF_RETURN_IF_ERROR(
+        writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
+    TF_RETURN_IF_ERROR(
+        writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
+    return Status::OK();
+  }
+
+  // Key for storing the Dataset graph in the serialized format.
+  static const char kDatasetGraphKey[];
+
+  // Key for storing the output node of the Dataset graph in the serialized
+  // format.
+  static const char kDatasetGraphOutputNodeKey[];
+
+ private:
+  Status Serialize(OpKernelContext* ctx, string* serialized_graph_def,
+                   string* output_node) const;
+
+  const string op_name_;
+};
+
+// Represents an iterator that is associated with a particular parent dataset.
+template <class DatasetType>
+class DatasetIterator : public IteratorBase {
+ public:
+  struct Params {
+    // Owns one reference on the shared dataset resource.
+    const DatasetType* dataset;
+
+    // Identifies the sequence of iterators leading up to this iterator.
+    const string prefix;
+  };
+
+  explicit DatasetIterator(const Params& params) : params_(params) {
+    params_.dataset->Ref();
+  }
+
+  ~DatasetIterator() override { params_.dataset->Unref(); }
+
+  // The dataset from which this iterator was created.
+  const DatasetType* dataset() const { return params_.dataset; }
+
+  // The sequence of iterators leading up to this iterator.
+  const string prefix() const { return params_.prefix; }
+
+  const DataTypeVector& output_dtypes() const override {
+    return params_.dataset->output_dtypes();
+  }
+
+  const std::vector<PartialTensorShape>& output_shapes() const override {
+    return params_.dataset->output_shapes();
+  }
+
+  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+                 bool* end_of_sequence) final {
+    port::Tracing::TraceMe activity(params_.prefix);
+    Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+    if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
+      s = errors::Internal(
+          "Iterator \"", params_.prefix,
+          "\" returned OutOfRange without setting `*end_of_sequence`. This "
+          "indicates that an error may have occurred. Original message: ",
+          s.error_message());
+      LOG(ERROR) << s;
+    }
+    return s;
+  }
+
+  Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) final {
+    TF_RETURN_IF_ERROR(dataset()->Save(ctx, writer));
+    return IteratorBase::Save(ctx, writer);
+  }
+
+ protected:
+  // Internal implementation of GetNext that is wrapped in tracing logic.
+  virtual Status GetNextInternal(IteratorContext* ctx,
+                                 std::vector<Tensor>* out_tensors,
+                                 bool* end_of_sequence) = 0;
+
+  string full_name(const string& name) const {
+    return strings::StrCat(prefix(), ":", name);
+  }
+
+ private:
+  Params params_;
+};
+
+// Encapsulates the work required to plug a DatasetBase into the core
+// TensorFlow graph execution engine.
+class DatasetOpKernel : public OpKernel {
+ public:
+  DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  void Compute(OpKernelContext* ctx) final;
+
+ protected:
+  // Subclasses should implement this method. It will be called during Compute
+  // execution.
+  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;
+
+  template <typename T>
+  Status ParseScalarArgument(OpKernelContext* ctx,
+                             const StringPiece& argument_name, T* output) {
+    const Tensor* argument_t;
+    TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
+    if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
+      return errors::InvalidArgument(argument_name, " must be a scalar");
+    }
+    *output = argument_t->scalar<T>()();
+    return Status::OK();
+  }
+};
+
+// Encapsulates the work required to plug unary Datasets into the core
+// TensorFlow graph execution engine.
+class UnaryDatasetOpKernel : public DatasetOpKernel {
+ public:
+  UnaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
+
+ protected:
+  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
+  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                           DatasetBase** output) = 0;
+};
+
+// Encapsulates the work required to plug binary Datasets into the core
+// TensorFlow graph execution engine.
+class BinaryDatasetOpKernel : public DatasetOpKernel {
+ public:
+  BinaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
+
+ protected:
+  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
+  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                           DatasetBase* another_input,
+                           DatasetBase** output) = 0;
+};
+
+// Validates and extracts a `DatasetBase` object from `tensor`.
+//
+// `tensor` must have been written by a call to SetVariantTensorToDataset().
+//
+// The retrieved pointer is a borrowed reference to the dataset, which is owned
+// by the tensor. The consumer must either acquire its own reference to the
+// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
+// destroyed or mutated while the retrieved pointer is in use.
+Status GetDatasetFromVariantTensor(const Tensor& tensor,
+                                   DatasetBase** out_dataset);
+
+// Stores a `DatasetBase` object in `tensor`.
+//
+// The ownership of `dataset` is transferred to `tensor`.
+Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
diff --git a/tensorflow/core/kernels/data/dataset.h b/tensorflow/core/kernels/data/dataset.h
index 8238661..2c6fc8d 100644
--- a/tensorflow/core/kernels/data/dataset.h
+++ b/tensorflow/core/kernels/data/dataset.h
@@ -15,600 +15,6 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_
 #define TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_
 
-#include <memory>
-
-#include "tensorflow/core/framework/attr_value.pb.h"
-#include "tensorflow/core/framework/attr_value_util.h"
-#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
-#include "tensorflow/core/framework/function.h"
-#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/framework/variant_encode_decode.h"
-#include "tensorflow/core/framework/variant_tensor_data.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/tracing.h"
-
-// Polymorphic datasets should support all primitive TensorFlow
-// types. Use this macro to expand `m(T)` once for each primitive type
-// `T`, e.g. to build a `switch` statement.
-#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
-
-namespace tensorflow {
-
-// Interface for reading values from a key-value store.
-// Used for restoring iterator state.
-class IteratorStateReader {
- public:
-  virtual Status ReadScalar(StringPiece key, int64* val) = 0;
-  virtual Status ReadScalar(StringPiece key, string* val) = 0;
-  virtual Status ReadTensor(StringPiece key, Tensor* val) = 0;
-  virtual bool Contains(StringPiece key) = 0;
-
-  virtual ~IteratorStateReader() {}
-};
-
-// Interface for writing values to a key-value store.
-// Used for saving iterator state.
-class IteratorStateWriter {
- public:
-  virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
-  virtual Status WriteScalar(StringPiece key, const string& val) = 0;
-  virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;
-
-  virtual ~IteratorStateWriter() {}
-};
-
-// Forward declarations to avoid introducing a dependency on headers in
-// "tensorflow/core/graph/...".
-class GraphDefBuilder;
-class GraphDatasetBase;
-class Node;
-
-// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
-class GraphDefBuilderWrapper {
- public:
-  explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {}
-
-  // Adds a Const node with scalar value to the Graph.
-  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
-  // non-null if the method returns with an OK status.
-  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
-  template <typename T>
-  Status AddScalar(const T& val, Node** output) {
-    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
-    val_t.scalar<T>()() = val;
-    AddTensorInternal(val_t, output);
-    if (*output == nullptr) {
-      return errors::Internal("AddScalar: Failed to build Const op.");
-    }
-    return Status::OK();
-  }
-
-  // Adds a Const node with vector value to the Graph.
-  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
-  // non-null if the method returns with an OK status.
-  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
-  // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice?
-  template <typename T>
-  Status AddVector(const std::vector<T>& val, Node** output) {
-    Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
-                          TensorShape({static_cast<int64>(val.size())}));
-    for (int i = 0; i < val.size(); i++) {
-      val_t.flat<T>()(i) = val[i];
-    }
-    AddTensorInternal(val_t, output);
-    if (*output == nullptr) {
-      return errors::Internal("AddVector: Failed to build Const op.");
-    }
-    return Status::OK();
-  }
-
-  // Adds a Const node with Tensor value to the Graph.
-  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
-  // non-null if the method returns with an OK status.
-  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
-  Status AddTensor(const Tensor& val, Node** output) {
-    AddTensorInternal(val, output);
-    if (*output == nullptr) {
-      return errors::Internal("AddTensor: Failed to build Const op.");
-    }
-    return Status::OK();
-  }
-
-  Status AddDataset(const GraphDatasetBase* dataset,
-                    const std::vector<Node*>& inputs, Node** output) {
-    return AddDataset(dataset, inputs, {}, output);
-  }
-
-  // Adds a node corresponding to the `DatasetType` to the Graph.
-  // Return value of `DatasetType::op_name()` is used as the op type for the
-  // node.
-  // Values for the output_types and output_shapes node attributes are also
-  // written if those attributes are defined in the OpDef.
-  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
-  // non-null if the method returns with an OK status.
-  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
-  Status AddDataset(const GraphDatasetBase* dataset,
-                    const std::vector<Node*>& inputs,
-                    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
-                    Node** output) {
-    std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
-    for (int i = 0; i < inputs.size(); i++) {
-      enumerated_inputs[i] = std::make_pair(i, inputs[i]);
-    }
-    return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
-  }
-
-  Status AddDataset(
-      const GraphDatasetBase* dataset,
-      const std::vector<std::pair<size_t, Node*>>& inputs,
-      const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
-      const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
-      Node** output);
-
-  // Adds a user-defined function with name `function_name` to the graph and
-  // recursively adds all functions it references. If a function with a
-  // matching name has already been added, returns with OK status. If a
-  // user-defined function with name `function_name` is not found in the
-  // FunctionLibraryDefinition, returns an InvalidArgumentError. If the
-  // function with name `function_name` or any of its dependent functions
-  // are stateful, returns an InvalidArgument error.
-  Status AddFunction(OpKernelContext* ctx, const string& function_name);
-
-  template <typename T>
-  void BuildAttrValue(const T& value, AttrValue* attr) {
-    SetAttrValue(value, attr);
-  }
-
- private:
-  void AddTensorInternal(const Tensor& val, Node** output);
-
-  Status EnsureFunctionIsStateless(OpKernelContext* ctx,
-                                   const string& function_name) const {
-    const FunctionLibraryDefinition* lib_def =
-        ctx->function_library()->GetFunctionLibraryDefinition();
-    const FunctionDef* function_def = lib_def->Find(function_name);
-    if (!function_def) {
-      return errors::InvalidArgument("Unable to find FunctionDef for ",
-                                     function_name, " in registry.");
-    }
-    for (const NodeDef& node_def : function_def->node_def()) {
-      const OpDef* op_def;
-      TF_RETURN_IF_ERROR(lib_def->LookUpOpDef(node_def.op(), &op_def));
-      // TODO(b/65524810): Hack to allow functions to capture Dataset op
-      // nodes needed for FlatMap. Currently, source dataset nodes have been
-      // marked stateful to avoid constant folding since we do not have a
-      // good way of serializing them.
-      if (IsOpWhitelisted(op_def)) {
-        continue;
-      }
-      if (op_def->is_stateful()) {
-        return errors::InvalidArgument(
-            "Op[name: ", node_def.name(), ", type: ", node_def.op(), "] ",
-            "in function ", function_name, " is stateful. ",
-            "Saving stateful functions is not supported yet.");
-      }
-    }
-    return Status::OK();
-  }
-
-  // Returns whether an op has been whitelisted for use inside map_fns.
-  // Uses a heuristic to whitelist source dataset ops which have been
-  // marked stateful due to b/65524810.
-  // Also looks up the `op_def->name` in the global
-  // `WhitelistedStatefulOpRegistry`.
-  bool IsOpWhitelisted(const OpDef* op_def) const {
-    return (StringPiece(op_def->name()).ends_with("Dataset") &&
-            op_def->output_arg_size() == 1 &&
-            op_def->output_arg(0).type() == DT_VARIANT) ||
-           dataset::WhitelistedStatefulOpRegistry::Global()->Contains(
-               op_def->name());
-  }
-
-  bool HasAttr(const string& op_type_name, const string& attr_name) const;
-
-  bool HasAttr(const OpDef* op_def, const string& attr_name) const {
-    for (auto attr : op_def->attr()) {
-      if (attr.name() == attr_name) {
-        return true;
-      }
-    }
-    return false;
-  }
-
-  Status AddAttrFunctions(const AttrValue& attr_value, OpKernelContext* ctx) {
-    if (attr_value.has_func()) {
-      TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name()));
-    } else if (attr_value.has_list()) {
-      for (const NameAttrList& name_attr_list : attr_value.list().func()) {
-        TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name()));
-      }
-    }
-    return Status::OK();
-  }
-
-  GraphDefBuilder* b_;
-};
-
-class StatsAggregator;
-
-// A cut-down version of OpKernelContext for running computations in
-// iterators. Note that we cannot simply use OpKernelContext here
-// because we might run computation in an iterator whose lifetime is
-// not nested within the lifetime of a single OpKernelContext
-// (e.g. asynchronous prefetching).
-//
-// TODO(mrry): We will probably need to support more of
-// OpKernelContext here. For example, should allocation be handled by
-// the IteratorContext?
-// TODO(mrry): We're making some daring assumptions about the lifetime
-// of the runner passed in here.
-// A runner will be deleted when the original step ends, but all existing
-// runners only close over session-lifetime (or longer-lived) state, so we can
-// make a copy of the function. There's nothing in the definition of the API
-// from which we took the runner to guarantee that what we are doing is safe.
-// We should formalize the properties here.
-class IteratorContext {
- public:
-  struct Params {
-    // Interface to operating system functionality.
-    Env* env;
-
-    // Function call support.
-    std::function<void(std::function<void()>)> runner = nullptr;
-
-    // A function that returns the current `StatsAggregator` instance to be
-    // used when recording statistics about the iterator.
-    //
-    // NOTE(mrry): This is somewhat awkward, because (i) the `StatsAggregator`
-    // is a property of the `IteratorResource` (which this class does not know
-    // about), and (ii) it can change after the `IteratorContext` has been
-    // created. Better suggestions are welcome!
-    std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter =
-        nullptr;
-
-    // The FunctionLibraryRuntime object to be used to make function calls.
-    FunctionLibraryRuntime* lib = nullptr;
-    std::shared_ptr<const FunctionLibraryDefinition> function_library = nullptr;
-
-    // The Allocator to be used to allocate the output of an iterator.
-    Allocator* allocator = nullptr;
-  };
-
-  explicit IteratorContext(Params params) : params_(std::move(params)) {}
-
-  Env* env() const { return params_.env; }
-
-  std::function<void(std::function<void()>)>* runner() {
-    return &params_.runner;
-  }
-
-  std::shared_ptr<StatsAggregator> stats_aggregator() {
-    if (params_.stats_aggregator_getter) {
-      return params_.stats_aggregator_getter();
-    } else {
-      return nullptr;
-    }
-  }
-
-  std::shared_ptr<const FunctionLibraryDefinition> function_library() {
-    return params_.function_library;
-  }
-
-  FunctionLibraryRuntime* lib() { return params_.lib; }
-
-  void set_lib(FunctionLibraryRuntime* lib) { params_.lib = lib; }
-
-  Allocator* allocator(AllocatorAttributes attrs);
-
- private:
-  Params params_;
-};
-
-// Represents the current position in a range of outputs, where the
-// range of outputs is typically represented by a `DatasetBase`,
-// defined below.
-class IteratorBase {
- public:
-  virtual ~IteratorBase() {}
-
-  // Gets the next output from the range that this iterator is traversing.
-  //
-  // If at least one output remains in this iterator's range, that
-  // output will be stored in `*out_tensors` and `false` will be
-  // stored in `*end_of_sequence`.
-  //
-  // If no more outputs remain in this iterator's range, `true` will
-  // be stored in `*end_of_sequence`, and the content of
-  // `*out_tensors` will be undefined.
-  //
-  // This method is thread-safe.
-  //
-  // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
-  // potentially remove this method.
-  virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
-                         bool* end_of_sequence) = 0;
-
-  // Returns a vector of DataType values, representing the respective
-  // element types of each tuple component in the outputs of this
-  // iterator.
-  virtual const DataTypeVector& output_dtypes() const = 0;
-
-  // Returns a vector of tensor shapes, representing the respective
-  // (and possibly partially defined) shapes of each tuple component
-  // in the outputs of this iterator.
-  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
-
-  // Saves the state of this iterator.
-  virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) {
-    return SaveInternal(writer);
-  }
-
-  // Restores the state of this iterator.
-  virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
-    return RestoreInternal(ctx, reader);
-  }
-
- protected:
-  // This is needed so that sub-classes of IteratorBase can call
-  // `SaveInternal` on their parent iterators, e.g., in
-  // `RepeatDatasetOp::Dataset`.
-  Status SaveParent(IteratorStateWriter* writer,
-                    const std::unique_ptr<IteratorBase>& parent) {
-    return parent->SaveInternal(writer);
-  }
-
-  // This is needed so that sub-classes of IteratorBase can call
-  // `RestoreInternal` on their parent iterators, e.g., in
-  // `RepeatDatasetOp::Dataset`.
-  Status RestoreParent(IteratorContext* ctx, IteratorStateReader* reader,
-                       const std::unique_ptr<IteratorBase>& parent) {
-    return parent->RestoreInternal(ctx, reader);
-  }
-
-  // Saves the state of this iterator recursively.
-  virtual Status SaveInternal(IteratorStateWriter* writer) {
-    return errors::Unimplemented("SaveInternal");
-  }
-
-  // Restores the state of this iterator recursively.
-  virtual Status RestoreInternal(IteratorContext* ctx,
-                                 IteratorStateReader* reader) {
-    return errors::Unimplemented("RestoreInternal");
-  }
-};
-
-// Represents a (potentially infinite) range of outputs, where each
-// output is a tuple of tensors.
-class DatasetBase : public core::RefCounted {
- public:
-  // Returns a new iterator for iterating over the range of elements in
-  // this dataset.
-  //
-  // This method may be called multiple times on the same instance,
-  // and the resulting iterators will have distinct state. Each
-  // iterator will traverse all elements in this dataset from the
-  // start.
-  //
-  // Ownership of the created iterator will be transferred to the caller.
-  //
-  // The prefix identifies the sequence of iterators leading up to the newly
-  // created iterator.
-  virtual std::unique_ptr<IteratorBase> MakeIterator(
-      const string& prefix) const = 0;
-
-  // Returns a vector of DataType values, representing the respective
-  // element types of each tuple component in the outputs of this
-  // dataset.
-  virtual const DataTypeVector& output_dtypes() const = 0;
-
-  // Returns a vector of tensor shapes, representing the respective
-  // (and possibly partially defined) shapes of each tuple component
-  // in the outputs of this dataset.
-  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
-
-  // A human-readable debug string for this dataset.
-  virtual string DebugString() = 0;
-
-  // Serializes the dataset and writes it to the `writer`.
-  virtual Status Save(OpKernelContext* ctx,
-                      IteratorStateWriter* writer) const {
-    return errors::Unimplemented("DatasetBase::Save");
-  }
-
- protected:
-  // TODO(srbs): Ideally all graph related logic should reside in
-  // GraphDatasetBase. However, that would require Datasets defined in all ops
-  // to derive from GraphDatasetBase. Once that is done we can move
-  // DatasetGraphDefBuilder and AsGraphDefInternal to GraphDatasetBase.
-  class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
-   public:
-    DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
-    Status AddParentDataset(OpKernelContext* ctx, const DatasetBase* dataset,
-                            Node** output) {
-      return dataset->AsGraphDefInternal(ctx, this, output);
-    }
-  };
-
-  virtual Status AsGraphDefInternal(OpKernelContext* ctx,
-                                    DatasetGraphDefBuilder* b,
-                                    Node** node) const {
-    return AsGraphDefInternal(b, node);
-  }
-
-  virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
-                                    Node** node) const {
-    return errors::Unimplemented("AsGraphDefInternal");
-  }
-};
-
-// Base-class for datasets that are built by ops.
-class GraphDatasetBase : public DatasetBase {
- public:
-  GraphDatasetBase(OpKernelContext* ctx)
-      : op_name_(ctx->op_kernel().type_string()) {}
-
-  const string op_name() const { return op_name_; }
-
-  Status Save(OpKernelContext* ctx,
-              IteratorStateWriter* writer) const override {
-    string serialized_graph_def;
-    string output_node;
-    TF_RETURN_IF_ERROR(Serialize(ctx, &serialized_graph_def, &output_node));
-    TF_RETURN_IF_ERROR(
-        writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
-    TF_RETURN_IF_ERROR(
-        writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
-    return Status::OK();
-  }
-
-  // Key for storing the Dataset graph in the serialized format.
-  static const char kDatasetGraphKey[];
-
-  // Key for storing the output node of the Dataset graph in the serialized
-  // format.
-  static const char kDatasetGraphOutputNodeKey[];
-
- private:
-  Status Serialize(OpKernelContext* ctx, string* serialized_graph_def,
-                   string* output_node) const;
-
-  const string op_name_;
-};
-
-// Represents an iterator that is associated with a particular parent dataset.
-template <class DatasetType>
-class DatasetIterator : public IteratorBase {
- public:
-  struct Params {
-    // Owns one reference on the shared dataset resource.
-    const DatasetType* dataset;
-
-    // Identifies the sequence of iterators leading up to this iterator.
-    const string prefix;
-  };
-
-  explicit DatasetIterator(const Params& params) : params_(params) {
-    params_.dataset->Ref();
-  }
-
-  ~DatasetIterator() override { params_.dataset->Unref(); }
-
-  // The dataset from which this iterator was created.
-  const DatasetType* dataset() const { return params_.dataset; }
-
-  // The sequence of iterators leading up to this iterator.
-  const string prefix() const { return params_.prefix; }
-
-  const DataTypeVector& output_dtypes() const override {
-    return params_.dataset->output_dtypes();
-  }
-
-  const std::vector<PartialTensorShape>& output_shapes() const override {
-    return params_.dataset->output_shapes();
-  }
-
-  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
-                 bool* end_of_sequence) final {
-    port::Tracing::TraceMe activity(params_.prefix);
-    Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
-    if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
-      s = errors::Internal(
-          "Iterator \"", params_.prefix,
-          "\" returned OutOfRange without setting `*end_of_sequence`. This "
-          "indicates that an error may have occurred. Original message: ",
-          s.error_message());
-      LOG(ERROR) << s;
-    }
-    return s;
-  }
-
-  Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) final {
-    TF_RETURN_IF_ERROR(dataset()->Save(ctx, writer));
-    return IteratorBase::Save(ctx, writer);
-  }
-
- protected:
-  // Internal implementation of GetNext that is wrapped in tracing logic.
-  virtual Status GetNextInternal(IteratorContext* ctx,
-                                 std::vector<Tensor>* out_tensors,
-                                 bool* end_of_sequence) = 0;
-
-  string full_name(const string& name) const {
-    return strings::StrCat(prefix(), ":", name);
-  }
-
- private:
-  Params params_;
-};
-
-// Encapsulates the work required to plug a DatasetBase into the core
-// TensorFlow graph execution engine.
-class DatasetOpKernel : public OpKernel {
- public:
-  DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-  void Compute(OpKernelContext* ctx) final;
-
- protected:
-  // Subclasses should implement this method. It will be called during Compute
-  // execution.
-  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;
-
-  template <typename T>
-  Status ParseScalarArgument(OpKernelContext* ctx,
-                             const StringPiece& argument_name, T* output) {
-    const Tensor* argument_t;
-    TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
-    if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
-      return errors::InvalidArgument(argument_name, " must be a scalar");
-    }
-    *output = argument_t->scalar<T>()();
-    return Status::OK();
-  }
-};
-
-// Encapsulates the work required to plug unary Datasets into the core
-// TensorFlow graph execution engine.
-class UnaryDatasetOpKernel : public DatasetOpKernel {
- public:
-  UnaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
-
- protected:
-  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
-  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
-                           DatasetBase** output) = 0;
-};
-
-// Encapsulates the work required to plug binary Datasets into the core
-// TensorFlow graph execution engine.
-class BinaryDatasetOpKernel : public DatasetOpKernel {
- public:
-  BinaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
-
- protected:
-  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
-  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
-                           DatasetBase* another_input,
-                           DatasetBase** output) = 0;
-};
-
-// Validates and extracts a `DatasetBase` object from `tensor`.
-//
-// `tensor` must have been written by a call to SetVariantTensorToDataset().
-//
-// The retrieved pointer is a borrowed reference to the dataset, which is owned
-// by the tensor. The consumer must either acquire its own reference to the
-// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
-// destroyed or mutated while the retrieved pointer is in use.
-Status GetDatasetFromVariantTensor(const Tensor& tensor,
-                                   DatasetBase** out_dataset);
-
-// Stores a `DatasetBase` object in `tensor`.
-//
-// The ownership of `dataset` is transferred to `tensor`.
-Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);
-
-}  // namespace tensorflow
 
+#include "tensorflow/core/framework/dataset.h"
 
 #endif  // TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_
-- 
2.7.4
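For readers who want to try the newly exported API, the following sketch shows
roughly what a minimal custom dataset built against this header looks like. It
is illustrative only and not part of the patch: the op name "MyRangeDataset",
the kernel and class names, and the input name "stop" are all hypothetical. A
production implementation would also override AsGraphDefInternal() so that
GraphDatasetBase::Save() can serialize the dataset; the default returns an
Unimplemented error. Assuming TensorFlow at this revision, this code could be
compiled into a custom op library (e.g. loaded with tf.load_op_library) rather
than linked statically into TensorFlow, which is exactly what this change
enables.

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {
namespace {

// Hypothetical source dataset op that yields int64 scalars 0, 1, ..., stop-1.
REGISTER_OP("MyRangeDataset")
    .Input("stop: int64")
    .Output("handle: variant")
    .SetIsStateful()  // Source datasets are marked stateful (see b/65524810).
    .SetShapeFn(shape_inference::ScalarShape);

class MyRangeDatasetOp : public DatasetOpKernel {
 public:
  explicit MyRangeDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
    int64 stop;
    OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "stop", &stop));
    // Ownership of the ref-counted dataset passes to the Compute() machinery.
    *output = new Dataset(ctx, stop);
  }

 private:
  class Dataset : public GraphDatasetBase {
   public:
    Dataset(OpKernelContext* ctx, int64 stop)
        : GraphDatasetBase(ctx), stop_(stop) {}

    std::unique_ptr<IteratorBase> MakeIterator(
        const string& prefix) const override {
      return std::unique_ptr<IteratorBase>(
          new Iterator({this, strings::StrCat(prefix, "::MyRange")}));
    }

    const DataTypeVector& output_dtypes() const override {
      static DataTypeVector* dtypes = new DataTypeVector({DT_INT64});
      return *dtypes;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      static std::vector<PartialTensorShape>* shapes =
          new std::vector<PartialTensorShape>({{}});
      return *shapes;
    }

    string DebugString() override { return "MyRangeDatasetOp::Dataset"; }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

     protected:
      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);  // GetNext() may be called concurrently.
        if (next_ >= dataset()->stop_) {
          *end_of_sequence = true;
          return Status::OK();
        }
        Tensor value(DT_INT64, TensorShape({}));
        value.scalar<int64>()() = next_++;
        out_tensors->emplace_back(std::move(value));
        *end_of_sequence = false;
        return Status::OK();
      }

     private:
      mutex mu_;
      int64 next_ GUARDED_BY(mu_) = 0;
    };

    const int64 stop_;
  };
};

REGISTER_KERNEL_BUILDER(Name("MyRangeDataset").Device(DEVICE_CPU),
                        MyRangeDatasetOp);

}  // namespace
}  // namespace tensorflow

The DatasetOpKernel base class handles wrapping the DatasetBase in a DT_VARIANT
output tensor (via StoreDatasetInVariantTensor), so the kernel only needs to
construct the dataset object itself.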