[tf.data] Adds support for adding scalar value to `StatsAggregator`.
authorShivani Agrawal <shivaniagrawal@google.com>
Thu, 26 Apr 2018 22:24:44 +0000 (15:24 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 26 Apr 2018 22:27:12 +0000 (15:27 -0700)
PiperOrigin-RevId: 194463407

tensorflow/core/framework/stats_aggregator.h
tensorflow/core/kernels/data/stats_aggregator_ops.cc

index a449f32..8002d92 100644 (file)
@@ -47,6 +47,10 @@ class StatsAggregator {
   virtual void AddToHistogram(const string& name,
                               gtl::ArraySlice<double> values) = 0;
 
+  // TODO(shivaniagarawal): consistency in double and float usage.
+  // Add the given `value` as Scalar with the given `name`.
+  virtual void AddScalar(const string& name, float value) = 0;
+
   // Stores a protocol buffer representation of the aggregator state in the
   // given `out_summary`.
   // TODO(mrry): Consider separating this method from the `StatsAggregator`
index dd37311..33a56b2 100644 (file)
@@ -38,6 +38,11 @@ class StatsAggregatorImpl : public StatsAggregator {
     }
   }
 
+  void AddScalar(const string& name, float value) override {
+    mutex_lock l(mu_);
+    scalars_[name] = value;
+  }
+
   void EncodeToProto(Summary* out_summary) override {
     mutex_lock l(mu_);
     for (const auto& pair : histograms_) {
@@ -49,11 +54,17 @@ class StatsAggregatorImpl : public StatsAggregator {
       histogram.EncodeToProto(value->mutable_histo(),
                               false /* doesn't preserve zero buckets */);
     }
+    for (const auto& pair : scalars_) {
+      Summary::Value* value = out_summary->add_value();
+      value->set_tag(pair.first);
+      value->set_simple_value(pair.second);
+    }
   }
 
  private:
   mutex mu_;
   std::unordered_map<string, histogram::Histogram> histograms_ GUARDED_BY(mu_);
+  std::unordered_map<string, float> scalars_ GUARDED_BY(mu_);
   TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorImpl);
 };