From f83843a4b8dde5e9306c2b91da8ccbd438a7265f Mon Sep 17 00:00:00 2001 From: Michael Kuperstein Date: Tue, 10 Apr 2018 16:45:19 -0700 Subject: [PATCH] Add a thread-safe producer-consumer queue. PiperOrigin-RevId: 192370670 --- tensorflow/compiler/jit/BUILD | 19 +++ tensorflow/compiler/jit/producer_consumer_queue.h | 132 +++++++++++++++++++ .../compiler/jit/producer_consumer_queue_test.cc | 139 +++++++++++++++++++++ 3 files changed, 290 insertions(+) create mode 100644 tensorflow/compiler/jit/producer_consumer_queue.h create mode 100644 tensorflow/compiler/jit/producer_consumer_queue_test.cc diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index a492fc6..4cefc08 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -318,6 +318,25 @@ cc_library( hdrs = ["union_find.h"], ) +cc_library( + name = "producer_consumer_queue", + hdrs = ["producer_consumer_queue.h"], + deps = ["//tensorflow/core:lib"], +) + +tf_cc_test( + name = "producer_consumer_queue_test", + size = "small", + srcs = ["producer_consumer_queue_test.cc"], + deps = [ + ":producer_consumer_queue", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test( name = "graph_to_functiondef_test", size = "small", diff --git a/tensorflow/compiler/jit/producer_consumer_queue.h b/tensorflow/compiler/jit/producer_consumer_queue.h new file mode 100644 index 0000000..7c8c041 --- /dev/null +++ b/tensorflow/compiler/jit/producer_consumer_queue.h @@ -0,0 +1,132 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ +#define TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ + +#include +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// A thread-safe, first-in-first-out queue. +template +class ProducerConsumerQueue { + public: + ProducerConsumerQueue() + : capacity_(std::numeric_limits::max()) {} + ~ProducerConsumerQueue() = default; + + // Wait until the queue is non-full, then append a copy of v. + void Put(const T &v); + + // Wait until the queue is non-empty, then remove and return the head value. + T Get(); + + // If the queue is non-empty, remove the head value, placing it in *pv, and + // return true; otherwise return false. + bool TryGet(T *pv); + + // Set the capacity of the queue; the queue is full whenever count() >= + // capacity(). The initial value is the maximum size_t. Requires size > 0. + void set_capacity(std::size_t size); + + // Return the capacity of the queue. + std::size_t capacity() const; + + // Return the number of elements in the queue. + std::size_t count() const; + + // Implementation details follow. Clients should ignore. + private: + mutable tensorflow::mutex mu_; // protects all fields below + tensorflow::condition_variable non_empty_ GUARDED_BY(mu_); + tensorflow::condition_variable non_full_ GUARDED_BY(mu_); + std::size_t capacity_ GUARDED_BY(mu_); + std::deque queue_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(ProducerConsumerQueue); +}; + +// ------------------------------------------------------ +// Implementation details follow. Clients should ignore. + +// Wait until the queue is non-full, then append a copy of v. +template +void ProducerConsumerQueue::Put(const T &v) { + mutex_lock lock(mu_); + while (queue_.size() >= capacity_) { + non_full_.wait(lock); + } + queue_.push_back(v); + non_empty_.notify_one(); +} + +// Wait until the queue is non-empty, then remove and return the head value. +template +T ProducerConsumerQueue::Get() { + mutex_lock lock(mu_); + while (queue_.empty()) { + non_empty_.wait(lock); + } + non_full_.notify_one(); + T result_value = queue_.front(); + queue_.pop_front(); + return result_value; +} + +// If the queue is non-empty, remove the head value, placing it in *pv, and +// return true; otherwise return false. +template +bool ProducerConsumerQueue::TryGet(T *pv) { + mutex_lock lock(mu_); + bool got_element = !queue_.empty(); + if (got_element) { + non_full_.notify_one(); + *pv = queue_.front(); + queue_.pop_front(); + } + return got_element; +} + +// Set the capacity of the queue; the queue is full whenever count() >= +// capacity(). The initial value is the maximum size_t. Requires size > 0. +template +void ProducerConsumerQueue::set_capacity(std::size_t size) { + mutex_lock lock(mu_); + CHECK_NE(size, 0); + capacity_ = size; + non_full_.notify_all(); +} + +// Return the capacity of the queue. +template +std::size_t ProducerConsumerQueue::capacity() const { + mutex_lock lock(mu_); + std::size_t max_elements = capacity_; + return max_elements; +} + +// Return the number of elements in the queue. +template +std::size_t ProducerConsumerQueue::count() const { + mutex_lock lock(mu_); + std::size_t num_elements = queue_.size(); + return num_elements; +} +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ diff --git a/tensorflow/compiler/jit/producer_consumer_queue_test.cc b/tensorflow/compiler/jit/producer_consumer_queue_test.cc new file mode 100644 index 0000000..f61260c --- /dev/null +++ b/tensorflow/compiler/jit/producer_consumer_queue_test.cc @@ -0,0 +1,139 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/producer_consumer_queue.h" + +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +typedef ProducerConsumerQueue IntQueue; + +// Insert integers between low inclusive and high exclusive into q. +void PushRange(IntQueue *q, int low, int high) { + while (low != high) { + q->Put(low); + VLOG(2) << "Pushing " << low; + ++low; + } +} + +// Push the numbers between 0 and 999 inclusive from several threads in the +// pool. +void PushRanges(IntQueue *queue, thread::ThreadPool *pool) { + VLOG(1) << "Adding 20-36"; + pool->Schedule([queue] { PushRange(queue, 20, 36); }); + VLOG(1) << "Adding 7-20"; + pool->Schedule([queue] { PushRange(queue, 7, 20); }); + VLOG(1) << "Adding 36-501"; + pool->Schedule([queue] { PushRange(queue, 36, 501); }); + VLOG(1) << "Adding 501-1000"; + pool->Schedule([queue] { PushRange(queue, 501, 1000); }); + VLOG(1) << "Adding 0-5"; + pool->Schedule([queue] { PushRange(queue, 0, 5); }); + VLOG(1) << "Adding 5-7"; + pool->Schedule([queue] { PushRange(queue, 5, 7); }); +} + +// Pop elements from queue using Get(). Make sure that exactly elements +// were present and their values are all integers between 0 and high-1 +// inclusive. +void GetRange(IntQueue *queue, int high) { + VLOG(1) << "Testing Wait"; + std::vector results; + for (int i = 0; i != high; ++i) { + int r = queue->Get(); + VLOG(2) << "Waited and got " << r; + results.push_back(r); + } + CHECK_EQ(queue->count(), 0); + std::sort(results.begin(), results.end()); + for (int i = 0; i != high; ++i) { + CHECK(results[i] == i); + } +} + +// Pop elements from queue using TryGet(). Make sure that exactly +// elements were present and their values are all integers between 0 and high-1 +// inclusive. +void TryGetRange(IntQueue *queue, int high) { + std::vector results; + // Give up if we don't get all the elements back from the queue + // in 10 seconds. + int timeout = 10; + int r; + for (int i = 0; i != high; ++i) { + while (!queue->TryGet(&r)) { + if (!timeout--) { + LOG(FATAL) << "Can't find all elements in the queue"; + } + VLOG(1) << "Sleeping for a second..."; + sleep(1); + } + VLOG(2) << "Popped " << r; + results.push_back(r); + } + CHECK_EQ(queue->count(), 0); + CHECK(!queue->TryGet(&r)); + std::sort(results.begin(), results.end()); + for (int i = 0; i != high; ++i) { + CHECK_EQ(i, results[i]); + } +} + +const int kNumThreads = 15; + +TEST(ProducerConsumerQueue, GetRange) { + IntQueue queue; + { + thread::ThreadPool pool(Env::Default(), "test", kNumThreads); + PushRanges(&queue, &pool); + } + GetRange(&queue, 1000); +} + +TEST(ProducerConsumerQueue, TryGetRange) { + IntQueue queue; + { + thread::ThreadPool pool(Env::Default(), "test", kNumThreads); + PushRanges(&queue, &pool); + } + TryGetRange(&queue, 1000); +} + +TEST(ProducerConsumerQueue, ParallelGetRange) { + IntQueue queue; + { + thread::ThreadPool pool(Env::Default(), "test", kNumThreads); + pool.Schedule([&queue] { GetRange(&queue, 1000); }); + PushRanges(&queue, &pool); + } +} + +TEST(ProducerConsumerQueue, ParallelTryGetRange) { + IntQueue queue; + { + thread::ThreadPool pool(Env::Default(), "test", kNumThreads); + pool.Schedule([&queue] { TryGetRange(&queue, 1000); }); + PushRanges(&queue, &pool); + } +} + +} // namespace +} // namespace tensorflow -- 2.7.4