From: Shivani Agrawal Date: Fri, 15 Dec 2017 20:23:55 +0000 (-0800) Subject: [tf.data] Saveable iterator for LatencyStatsDataset. X-Git-Tag: v1.6.0-rc0~428^2~3^2~107 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ed24130f90c2c45db0473df3e9158d4895ce326b;p=platform%2Fupstream%2Ftensorflow.git [tf.data] Saveable iterator for LatencyStatsDataset. PiperOrigin-RevId: 179225632 --- diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 2b04b278ba..07bdf92044 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -224,6 +224,34 @@ class StatsDatasetSerializationTest( lambda: self._build_dataset_bytes_stats(num_outputs), lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs) + def _build_dataset_latency_stats(self, num_elements, tag="record_latency"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag)) + + def _build_dataset_multiple_tags(self, + num_elements, + tag1="record_latency", + tag2="record_latency_2"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2)) + + def testLatencyStatsDatasetSaveableCore(self): + num_outputs = 100 + + self.run_core_tests( + lambda: self._build_dataset_latency_stats(num_outputs), + lambda: self._build_dataset_latency_stats(num_outputs // 10), + num_outputs) + + self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs), + None, num_outputs) + + tag1 = "record_latency" + tag2 = "record_latency" + self.run_core_tests( + lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2), + None, num_outputs) + if __name__ == "__main__": test.main() diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc index 09704d4b25..8742e6c55f 100644 --- a/tensorflow/core/kernels/data/stats_dataset_ops.cc +++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc @@ -43,14 +43,14 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { DatasetBase** output) override { string tag; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag)); - *output = new Dataset(input, std::move(tag)); + *output = new Dataset(ctx, input, std::move(tag)); } private: - class Dataset : public DatasetBase { + class Dataset : public GraphDatasetBase { public: - explicit Dataset(const DatasetBase* input, string tag) - : input_(input), tag_(std::move(tag)) { + explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag) + : GraphDatasetBase(ctx), input_(input), tag_(std::move(tag)) { input_->Ref(); } @@ -71,6 +71,17 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { string DebugString() override { return "LatencyStatsDatasetOp::Dataset"; } + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_node; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node)); + Node* tag_node; + TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, tag_node}, output)); + return Status::OK(); + } + private: class Iterator : public DatasetIterator { public: @@ -81,6 +92,7 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { + tf_shared_lock l(mu_); uint64 start = ctx->env()->NowMicros(); Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); uint64 end = ctx->env()->NowMicros(); @@ -92,8 +104,23 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { return s; } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + return Status::OK(); + } + private: - const std::unique_ptr input_impl_; + mutex mu_; + std::unique_ptr input_impl_ GUARDED_BY(mu_); }; const DatasetBase* const input_;