From: Shivani Agrawal Date: Thu, 26 Apr 2018 22:24:44 +0000 (-0700) Subject: [tf.data] Adds support for adding scalar value to `StatsAggregator`. X-Git-Tag: upstream/v1.9.0_rc1~206^2~1^2~15 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=2808c3f05f7713ff1ab20f365e986a4651180376;p=platform%2Fupstream%2Ftensorflow.git [tf.data] Adds support for adding scalar value to `StatsAggregator`. PiperOrigin-RevId: 194463407 --- diff --git a/tensorflow/core/framework/stats_aggregator.h b/tensorflow/core/framework/stats_aggregator.h index a449f32..8002d92 100644 --- a/tensorflow/core/framework/stats_aggregator.h +++ b/tensorflow/core/framework/stats_aggregator.h @@ -47,6 +47,10 @@ class StatsAggregator { virtual void AddToHistogram(const string& name, gtl::ArraySlice values) = 0; + // TODO(shivaniagarawal): consistency in double and float usage. + // Add the given `value` as Scalar with the given `name`. + virtual void AddScalar(const string& name, float value) = 0; + // Stores a protocol buffer representation of the aggregator state in the // given `out_summary`. // TODO(mrry): Consider separating this method from the `StatsAggregator` diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc index dd37311..33a56b2 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc @@ -38,6 +38,11 @@ class StatsAggregatorImpl : public StatsAggregator { } } + void AddScalar(const string& name, float value) override { + mutex_lock l(mu_); + scalars_[name] = value; + } + void EncodeToProto(Summary* out_summary) override { mutex_lock l(mu_); for (const auto& pair : histograms_) { @@ -49,11 +54,17 @@ class StatsAggregatorImpl : public StatsAggregator { histogram.EncodeToProto(value->mutable_histo(), false /* doesn't preserve zero buckets */); } + for (const auto& pair : scalars_) { + Summary::Value* value = out_summary->add_value(); + value->set_tag(pair.first); + value->set_simple_value(pair.second); + } } private: mutex mu_; std::unordered_map histograms_ GUARDED_BY(mu_); + std::unordered_map scalars_ GUARDED_BY(mu_); TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorImpl); };