[tf.data] Fixing concurrency issue in `map_and_batch`.
authorJiri Simsa <jsimsa@google.com>
Sat, 26 May 2018 00:05:33 +0000 (17:05 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 26 May 2018 00:08:13 +0000 (17:08 -0700)
PiperOrigin-RevId: 198124860

tensorflow/core/kernels/data/map_and_batch_dataset_op.cc

index 879bb40..f41a810 100644 (file)
@@ -211,6 +211,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
+        mutex_lock external_l(external_mu_);
         mutex_lock l(mu_);
         EnsureRunnerThreadStarted(ctx);
         BatchResult* result = &batch_results_[ComputeIndex(input_batch_)];
@@ -220,6 +221,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
 
      protected:
       Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock external_l(external_mu_);
         mutex_lock l(mu_);
         // Wait for all in-flight calls to complete.
         while (num_calls_ > 0) {
@@ -243,6 +245,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
 
       Status RestoreInternal(IteratorContext* ctx,
                              IteratorStateReader* reader) override {
+        mutex_lock external_l(external_mu_);
         mutex_lock l(mu_);
         TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
         TF_RETURN_IF_ERROR(
@@ -629,6 +632,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
         return Status::OK();
       }
 
+      // Used for coordination between the main thread, the runner thread, and
+      // the callback threads.
       mutex mu_;
       // Used for coordination between the main thread, the runner thread, and
       // the callback threads. In particular, the runner thread should only
@@ -636,6 +641,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
       // user specified level of parallelism and there are slots available in
       // the `batch_results_` buffer.
       condition_variable cond_var_;
+      // Used for serializing external parallelism.
+      mutex external_mu_ ACQUIRED_BEFORE(mu_);
       // Counts the number of outstanding calls for this batch.
       int64 num_calls_ GUARDED_BY(mu_) = 0;
       // Counts the total number of calls.