[tf.data] Add `IteratorContext::allocator()`.
authorDerek Murray <mrry@google.com>
Tue, 30 Jan 2018 21:26:51 +0000 (13:26 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 30 Jan 2018 21:46:56 +0000 (13:46 -0800)
This enables the various iterator implementations to use the actual allocator for the device on which they are running, rather than defaulting to `cpu_allocator()` (which is typically a plain malloc). In future, this will enable allocating iterator outputs in CUDA-pinned memory (and GPU memory).

PERFORMANCE NOTE: In sessions where `ConfigProto.force_gpu_compatible == True`, this change has the effect of allocating all input pipeline tensors in CUDA-pinned memory. Previous if this flag was set, only the tensors allocated during function execution would be allocated in this space, and other tensors (e.g. the result of a `Dataset.batch()` would be allocated using `cpu_allocator()` (i.e. `malloc()`). This change should lead to more efficient communication between a host-side input pipeline and GPUs, but it may also create more pressure on the CUDA host allocator (whose default maximum size is 64GB). The "TF_CUDA_HOST_MEM_LIMIT_IN_MB" environment variable can be used to override this value.

This change is a starting point for working on issue #13610.

PiperOrigin-RevId: 183881907

16 files changed:
tensorflow/core/kernels/data/BUILD
tensorflow/core/kernels/data/batch_dataset_op.cc
tensorflow/core/kernels/data/dataset.cc
tensorflow/core/kernels/data/dataset.h
tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
tensorflow/core/kernels/data/padded_batch_dataset_op.cc
tensorflow/core/kernels/data/random_dataset_op.cc
tensorflow/core/kernels/data/range_dataset_op.cc
tensorflow/core/kernels/data/reader_dataset_ops.cc
tensorflow/core/kernels/data/sql/BUILD
tensorflow/core/kernels/data/sql/query_connection.h
tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
tensorflow/core/kernels/data/sql/sqlite_query_connection.h
tensorflow/core/kernels/data/sql_dataset_ops.cc
tensorflow/core/kernels/data/tensor_slice_dataset_op.cc

index 45505ef716fa801e4740424374aeb4fe8f5a29b7..cdb402386137c5752be5bdf6dcc7414f0879fd47 100644 (file)
@@ -49,6 +49,7 @@ cc_library(
     srcs = ["dataset.cc"],
     hdrs = ["dataset.h"],
     deps = [
+        "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
         "//tensorflow/core:lib",
index 0853362b268a035f7cf6340e0bd1310707604bed..7fa67efb9e22e6877b97524150b9024521619dbc 100644 (file)
@@ -144,7 +144,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
           const Tensor& first_element = batch_elements[0][component_index];
           TensorShape batch_component_shape({num_batch_elements});
           batch_component_shape.AppendShape(first_element.shape());
-          Tensor batch_component(cpu_allocator(), first_element.dtype(),
+          Tensor batch_component(ctx->allocator({}), first_element.dtype(),
                                  batch_component_shape);
           // Build the output tuple component by copying one slice
           // from each input element in the batch.
index 2ea6875567604e4e5bf7c990ad6a42ed8c5dafaa..d18cb160189e832592b2bfdf7769396010859cc6 100644 (file)
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/graph/node_builder.h"
 
@@ -264,6 +265,10 @@ void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
   MakeDataset(ctx, input, another_input, output);
 }
 
+Allocator* IteratorContext::allocator(AllocatorAttributes attrs) {
+  return params_.lib->device()->GetAllocator(attrs);
+}
+
 const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
 const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] =
     "_DATASET_GRAPH_OUTPUT_NODE";
index 2ef31ddfaaa2fd1bd6a4898726d788d1ceece82e..08c3ca82eab5c79ced6e8de7fc02f88782d0564a 100644 (file)
@@ -272,6 +272,9 @@ class IteratorContext {
     // The FunctionLibraryRuntime object to be used to make function calls.
     FunctionLibraryRuntime* lib = nullptr;
     std::shared_ptr<const FunctionLibraryDefinition> function_library = nullptr;
+
+    // The Allocator to be used to allocate the output of an iterator.
+    Allocator* allocator = nullptr;
   };
 
   explicit IteratorContext(Params params) : params_(std::move(params)) {}
@@ -298,6 +301,8 @@ class IteratorContext {
 
   void set_lib(FunctionLibraryRuntime* lib) { params_.lib = lib; }
 
+  Allocator* allocator(AllocatorAttributes attrs);
+
  private:
   Params params_;
 };
index e7224bb547f60f943c7c91c37edfbbf561f5351a..132808a5f140a31fc3c1852cb83e5cd8579b6d95 100644 (file)
@@ -155,7 +155,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
 
         // Determine the size of the output tensors:
         // * dense_shape will be [`row_shape + 1`].
-        Tensor dense_shape(cpu_allocator(), DT_INT64, {row_ndims + 1});
+        Tensor dense_shape(ctx->allocator({}), DT_INT64, {row_ndims + 1});
         auto dense_shape_vec = dense_shape.vec<int64>();
         for (size_t i = 0; i < row_ndims; ++i) {
           if (row_shape.dim_size(i) == -1) {
@@ -215,10 +215,10 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
 
         // * indices will be [`total_elements`, `row_shape + 1`].
         // * values will be [`total_elements`].
-        Tensor indices(cpu_allocator(), DT_INT64,
+        Tensor indices(ctx->allocator({}), DT_INT64,
                        {total_elements, row_ndims + 1});
         Tensor values(
-            cpu_allocator(),
+            ctx->allocator({}),
             DatasetIterator<Dataset<T>>::dataset()->input_->output_dtypes()[0],
             {total_elements});
         auto indices_matrix = indices.matrix<int64>();
index c529f671f2bb7fd3eb5277c23867e25ba70fd046..9ce263732f6e6c907dfdc89692455daa5dca86d1 100644 (file)
@@ -183,7 +183,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
               TensorShape component_shape(
                   batch_results_[current_batch_index_].output[i].shape());
               component_shape.set_dim(0, num_elements);
-              Tensor component(cpu_allocator(), output[i].dtype(),
+              Tensor component(ctx->allocator({}), output[i].dtype(),
                                component_shape);
               TF_RETURN_IF_ERROR(
                   CopyPartialBatch(&component, output[i], num_elements));
@@ -244,7 +244,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
         return Status::OK();
       }
 
-      void EnsureOutputAllocated(BatchResult* batch_result,
+      void EnsureOutputAllocated(IteratorContext* ctx,
+                                 BatchResult* batch_result,
                                  const std::vector<Tensor>& return_values) {
         mutex_lock l(batch_result->mu);
         if (batch_result->output_allocated) {
@@ -254,7 +255,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
         for (size_t i = 0; i < num_components; ++i) {
           TensorShape component_shape({dataset()->batch_size_});
           component_shape.AppendShape(return_values[i].shape());
-          Tensor component(cpu_allocator(), return_values[i].dtype(),
+          Tensor component(ctx->allocator({}), return_values[i].dtype(),
                            component_shape);
           batch_result->output.emplace_back(std::move(component));
         }
@@ -285,10 +286,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
               dataset()->captured_func_->RunAsync(
                   ctx, std::move(input_element), &result->return_values,
                   [this, ctx, result, batch_result, offset](Status ret_status) {
-                    delete ctx;
                     result->status.Update(ret_status);
                     if (ret_status.ok()) {
-                      EnsureOutputAllocated(batch_result,
+                      EnsureOutputAllocated(ctx, batch_result,
                                             result->return_values);
                       const size_t num_components =
                           result->return_values.size();
@@ -318,6 +318,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
                         }
                       }
                     }
+                    delete ctx;
                     // NOTE(mrry): We clear the return values here to release
                     // any memory associated with them and to paralellize the
                     // destruction of the tensors (which can be surprisingly
index 346eca0bb2ab1c7a82ddba98063c0ccb71b4e58f..4fe4e8e2940b3725d9ac1fc1a508ea2cccfe79fe 100644 (file)
@@ -376,7 +376,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
 
           // 2. Copy each batch element to the appropriate location in
           // the output component tensor.
-          Tensor batch_component(cpu_allocator(),
+          Tensor batch_component(ctx->allocator({}),
                                  output_dtypes()[component_index],
                                  batch_component_shape);
           TF_RETURN_IF_ERROR(SetElementZero(
index bc638864b0147f4d71b3382ea320453e972ba8d7..210b9ad1b84eeb0c106b0ee538b4957aba7ce1b2 100644 (file)
@@ -99,7 +99,7 @@ class RandomDatasetOp : public DatasetOpKernel {
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
         mutex_lock l(mu_);
-        Tensor value_tensor(cpu_allocator(), DT_INT64, {});
+        Tensor value_tensor(ctx->allocator({}), DT_INT64, {});
         value_tensor.scalar<int64>()() = Random();
         out_tensors->emplace_back(std::move(value_tensor));
         *end_of_sequence = false;
index d0bc61acd99afae14ddc8a3e678acb4197fcea71..b57518e678ed185a183e0413d6e90f2a9f85e9fc 100644 (file)
@@ -100,7 +100,7 @@ class RangeDatasetOp : public DatasetOpKernel {
           *end_of_sequence = true;
           return Status::OK();
         }
-        Tensor value_tensor(cpu_allocator(), DT_INT64, {});
+        Tensor value_tensor(ctx->allocator({}), DT_INT64, {});
         value_tensor.scalar<int64>()() = next_;
         out_tensors->emplace_back(std::move(value_tensor));
         *end_of_sequence = false;
index aa39fffc2e344db8143b700cbba4c29bdb134964..34d7d9f914d7a726135febabb1fbe35b0146977c 100644 (file)
@@ -141,7 +141,7 @@ class TextLineDatasetOp : public DatasetOpKernel {
 
             if (s.ok()) {
               // Produce the line as output.
-              Tensor line_tensor(cpu_allocator(), DT_STRING, {});
+              Tensor line_tensor(ctx->allocator({}), DT_STRING, {});
               line_tensor.scalar<string>()() = line_contents;
               out_tensors->emplace_back(std::move(line_tensor));
               *end_of_sequence = false;
@@ -384,7 +384,7 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
               TF_RETURN_IF_ERROR(
                   input_buffer_->ReadNBytes(dataset()->record_bytes_, &record));
               // Produce the record as output.
-              Tensor record_tensor(cpu_allocator(), DT_STRING, {});
+              Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
               record_tensor.scalar<string>()() = record;
               out_tensors->emplace_back(std::move(record_tensor));
               *end_of_sequence = false;
@@ -589,7 +589,7 @@ class TFRecordDatasetOp : public DatasetOpKernel {
         do {
           // We are currently processing a file, so try to read the next record.
           if (reader_) {
-            Tensor result_tensor(cpu_allocator(), DT_STRING, {});
+            Tensor result_tensor(ctx->allocator({}), DT_STRING, {});
             Status s = reader_->ReadRecord(&result_tensor.scalar<string>()());
             if (s.ok()) {
               out_tensors->emplace_back(std::move(result_tensor));
index 0286825af3ef7c04fff6911ddf7daec76479a715..f4698bdaf7ae9767e068e49dad61d2a3d9f739a8 100644 (file)
@@ -33,6 +33,7 @@ cc_library(
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "//tensorflow/core/kernels/data:dataset",
         "//tensorflow/core/lib/db:sqlite",
     ],
 )
index f31017bd1981c3809d9b7daaa2dc56256d19d914..e9ffca202ff32f0c0130427c2699ce0449a0903a 100644 (file)
@@ -19,6 +19,8 @@ limitations under the License.
 
 namespace tensorflow {
 
+class IteratorContext;
+
 namespace sql {
 // This interface allows a user to connect to a database, execute a query, and
 // iterate over the result set, putting the results into an output tensor.
@@ -56,7 +58,7 @@ class QueryConnection {
   // If there are no more rows in the result set, then instead `true` will be
   // stored in `*end_of_sequence`, and the content of `*out_tensors` will be
   // undefined.
-  virtual Status GetNext(std::vector<Tensor>* out_tensors,
+  virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                          bool* end_of_sequence) = 0;
 };
 
index 029a0aab97290e30783e415274323a1e43f9740b..7cd07bd8eca160bfc62e15adc568742c84711779 100644 (file)
@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/data/sql/sqlite_query_connection.h"
 
 #include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/data/dataset.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 
 namespace tensorflow {
@@ -48,14 +49,16 @@ Status SqliteQueryConnection::Close() {
   return Status::OK();
 }
 
-Status SqliteQueryConnection::GetNext(std::vector<Tensor>* out_tensors,
+Status SqliteQueryConnection::GetNext(IteratorContext* ctx,
+                                      std::vector<Tensor>* out_tensors,
                                       bool* end_of_sequence) {
   if (!stmt_) TF_RETURN_IF_ERROR(PrepareQuery());
   TF_RETURN_IF_ERROR(stmt_.Step(end_of_sequence));
   if (!*end_of_sequence) {
     for (int i = 0; i < column_count_; i++) {
       DataType dt = output_types_[i];
-      Tensor tensor(cpu_allocator(), dt, {});
+      // TODO(mrry): Pass in the `IteratorContext::allocator()`.
+      Tensor tensor(ctx->allocator({}), dt, {});
       FillTensorWithResultSetEntry(dt, i, &tensor);
       out_tensors->emplace_back(std::move(tensor));
     }
index 787c17d6c00d99afad3d7814c3c2daaf4295b1b3..81b19530b7d5964e17bde996de9fa7766af318b7 100644 (file)
@@ -32,7 +32,7 @@ class SqliteQueryConnection : public QueryConnection {
   Status Open(const string& data_source_name, const string& query,
               const DataTypeVector& output_types) override;
   Status Close() override;
-  Status GetNext(std::vector<Tensor>* out_tensors,
+  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                  bool* end_of_sequence) override;
 
  private:
index 72302190802d17f2cb1ed5471017180238aedff3..d50e9c9cf9739044379c7bbe753fc4acc2de311e 100644 (file)
@@ -116,7 +116,7 @@ class SqlDatasetOp : public DatasetOpKernel {
         }
       }
 
-      Status GetNextInternal(IteratorContext* /*ctx*/,
+      Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
         mutex_lock l(mu_);
@@ -132,7 +132,7 @@ class SqlDatasetOp : public DatasetOpKernel {
             return s;
           }
         }
-        return query_connection_->GetNext(out_tensors, end_of_sequence);
+        return query_connection_->GetNext(ctx, out_tensors, end_of_sequence);
       }
 
      private:
index 18adae1ea32316ffd995a95fb25198309fda3361..d5be4c778074e406122dc3a1a9c23681fca491d0 100644 (file)
@@ -117,7 +117,7 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
           out_tensors->reserve(dataset()->tensors_.size());
           for (int i = 0; i < dataset()->tensors_.size(); ++i) {
             const Tensor& t = dataset()->tensors_[i];
-            Tensor t_slice(cpu_allocator(), t.dtype(),
+            Tensor t_slice(ctx->allocator({}), t.dtype(),
                            TensorShape(dataset()->shapes_[i].dim_sizes()));
             TF_RETURN_IF_ERROR(batch_util::CopySliceToElement(t, &t_slice, i_));
             out_tensors->emplace_back(std::move(t_slice));