Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
+ mutex_lock external_l(external_mu_);
mutex_lock l(mu_);
EnsureRunnerThreadStarted(ctx);
BatchResult* result = &batch_results_[ComputeIndex(input_batch_)];
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock external_l(external_mu_);
mutex_lock l(mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
+ mutex_lock external_l(external_mu_);
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
TF_RETURN_IF_ERROR(
return Status::OK();
}
+ // Used for coordination between the main thread, the runner thread, and
+ // the callback threads.
mutex mu_;
// Used for coordination between the main thread, the runner thread, and
// the callback threads. In particular, the runner thread should only
// user specified level of parallelism and there are slots available in
// the `batch_results_` buffer.
condition_variable cond_var_;
+ // Used for serializing external parallelism.
+ mutex external_mu_ ACQUIRED_BEFORE(mu_);
// Counts the number of outstanding calls for this batch.
int64 num_calls_ GUARDED_BY(mu_) = 0;
// Counts the total number of calls.