[tf.data] Saveable iterator for LatencyStatsDataset.
authorShivani Agrawal <shivaniagrawal@google.com>
Fri, 15 Dec 2017 20:23:55 +0000 (12:23 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 15 Dec 2017 20:27:45 +0000 (12:27 -0800)
PiperOrigin-RevId: 179225632

tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
tensorflow/core/kernels/data/stats_dataset_ops.cc

index 2b04b278bae0659cf9b2b1cfd7c55bb1a80dfd07..07bdf920446e953c2a1abaf495d2e9e1256106fd 100644 (file)
@@ -224,6 +224,34 @@ class StatsDatasetSerializationTest(
         lambda: self._build_dataset_bytes_stats(num_outputs),
         lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs)
 
+  def _build_dataset_latency_stats(self, num_elements, tag="record_latency"):
+    return dataset_ops.Dataset.range(num_elements).apply(
+        stats_ops.latency_stats(tag))
+
+  def _build_dataset_multiple_tags(self,
+                                   num_elements,
+                                   tag1="record_latency",
+                                   tag2="record_latency_2"):
+    return dataset_ops.Dataset.range(num_elements).apply(
+        stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2))
+
+  def testLatencyStatsDatasetSaveableCore(self):
+    num_outputs = 100
+
+    self.run_core_tests(
+        lambda: self._build_dataset_latency_stats(num_outputs),
+        lambda: self._build_dataset_latency_stats(num_outputs // 10),
+        num_outputs)
+
+    self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs),
+                        None, num_outputs)
+
+    tag1 = "record_latency"
+    tag2 = "record_latency"
+    self.run_core_tests(
+        lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
+        None, num_outputs)
+
 
 if __name__ == "__main__":
   test.main()
index 09704d4b25fdef85c729617f96d2e84e86aa4ff8..8742e6c55f059f303a47c3c7fae6576b2b897e37 100644 (file)
@@ -43,14 +43,14 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
                    DatasetBase** output) override {
     string tag;
     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
-    *output = new Dataset(input, std::move(tag));
+    *output = new Dataset(ctx, input, std::move(tag));
   }
 
  private:
-  class Dataset : public DatasetBase {
+  class Dataset : public GraphDatasetBase {
    public:
-    explicit Dataset(const DatasetBase* input, string tag)
-        : input_(input), tag_(std::move(tag)) {
+    explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag)
+        : GraphDatasetBase(ctx), input_(input), tag_(std::move(tag)) {
       input_->Ref();
     }
 
@@ -71,6 +71,17 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
 
     string DebugString() override { return "LatencyStatsDatasetOp::Dataset"; }
 
+   protected:
+    Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      Node* input_node;
+      TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
+      Node* tag_node;
+      TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node));
+      TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, tag_node}, output));
+      return Status::OK();
+    }
+
    private:
     class Iterator : public DatasetIterator<Dataset> {
      public:
@@ -81,6 +92,7 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
+        tf_shared_lock l(mu_);
         uint64 start = ctx->env()->NowMicros();
         Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
         uint64 end = ctx->env()->NowMicros();
@@ -92,8 +104,23 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
         return s;
       }
 
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+        return Status::OK();
+      }
+
+      Status RestoreInternal(OpKernelContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        return Status::OK();
+      }
+
      private:
-      const std::unique_ptr<IteratorBase> input_impl_;
+      mutex mu_;
+      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
     };
 
     const DatasetBase* const input_;