From e140ab8d07dd6aec70f61c0c6939506f6e67ac5e Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Fri, 25 May 2018 17:05:33 -0700 Subject: [PATCH] [tf.data] Fixing concurrency issue in `map_and_batch`. PiperOrigin-RevId: 198124860 --- tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index 879bb40..f41a810 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -211,6 +211,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { + mutex_lock external_l(external_mu_); mutex_lock l(mu_); EnsureRunnerThreadStarted(ctx); BatchResult* result = &batch_results_[ComputeIndex(input_batch_)]; @@ -220,6 +221,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { protected: Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock external_l(external_mu_); mutex_lock l(mu_); // Wait for all in-flight calls to complete. while (num_calls_ > 0) { @@ -243,6 +245,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { + mutex_lock external_l(external_mu_); mutex_lock l(mu_); TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); TF_RETURN_IF_ERROR( @@ -629,6 +632,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } + // Used for coordination between the main thread, the runner thread, and + // the callback threads. mutex mu_; // Used for coordination between the main thread, the runner thread, and // the callback threads. In particular, the runner thread should only @@ -636,6 +641,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { // user specified level of parallelism and there are slots available in // the `batch_results_` buffer. condition_variable cond_var_; + // Used for serializing external parallelism. + mutex external_mu_ ACQUIRED_BEFORE(mu_); // Counts the number of outstanding calls for this batch. int64 num_calls_ GUARDED_BY(mu_) = 0; // Counts the total number of calls. -- 2.7.4