Move SummaryFileWriter alongside SummaryDbWriter
authorJustine Tunney <jart@google.com>
Fri, 12 Jan 2018 01:14:28 +0000 (17:14 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 12 Jan 2018 01:18:28 +0000 (17:18 -0800)
These impls are two peas in the same pod. This will make it easier to
write a follow-up change that refactors out common code.

PiperOrigin-RevId: 181684341

tensorflow/contrib/tensorboard/db/BUILD
tensorflow/contrib/tensorboard/db/summary_file_writer.cc [new file with mode: 0644]
tensorflow/contrib/tensorboard/db/summary_file_writer.h [new file with mode: 0644]
tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc [new file with mode: 0644]
tensorflow/core/kernels/BUILD
tensorflow/core/kernels/summary_interface.cc [deleted file]
tensorflow/core/kernels/summary_interface.h
tensorflow/core/kernels/summary_interface_test.cc [deleted file]
tensorflow/core/kernels/summary_kernels.cc

index 4c9cc4ccd6e93151618d203a104217c90ad9a526..f4150673c7eef815b701e1e6014998a5fb1bf2a4 100644 (file)
@@ -60,6 +60,36 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "summary_file_writer",
+    srcs = ["summary_file_writer.cc"],
+    hdrs = ["summary_file_writer.h"],
+    copts = tf_copts(),
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:ptr_util",
+        "//tensorflow/core/kernels:summary_interface",
+    ],
+)
+
+tf_cc_test(
+    name = "summary_file_writer_test",
+    size = "medium",  # file i/o
+    timeout = "short",
+    srcs = ["summary_file_writer_test.cc"],
+    deps = [
+        ":summary_file_writer",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(["*"]),
diff --git a/tensorflow/contrib/tensorboard/db/summary_file_writer.cc b/tensorflow/contrib/tensorboard/db/summary_file_writer.cc
new file mode 100644 (file)
index 0000000..b4d379d
--- /dev/null
@@ -0,0 +1,462 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/tensorboard/db/summary_file_writer.h"
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/summary.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/histogram/histogram.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/png/png_io.h"
+#include "tensorflow/core/lib/wav/wav_io.h"
+#include "tensorflow/core/util/events_writer.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace {
+
+template <typename T>
+Status TensorValueAt(Tensor t, int64 index, T* out) {
+  switch (t.dtype()) {
+    case DT_FLOAT:
+      *out = t.flat<float>()(index);
+      break;
+    case DT_DOUBLE:
+      *out = t.flat<double>()(index);
+      break;
+    case DT_HALF:
+      *out = T(t.flat<Eigen::half>()(index));
+      break;
+    case DT_INT32:
+      *out = t.flat<int32>()(index);
+      break;
+    case DT_UINT8:
+      *out = t.flat<uint8>()(index);
+      break;
+    case DT_INT16:
+      *out = t.flat<int16>()(index);
+      break;
+    case DT_INT8:
+      *out = t.flat<int8>()(index);
+      break;
+    case DT_BOOL:
+      *out = t.flat<bool>()(index);
+      break;
+    case DT_INT64:
+      *out = t.flat<int64>()(index);
+      break;
+    default:
+      return errors::Unimplemented("Scalar summary for dtype ",
+                                   DataTypeString(t.dtype()),
+                                   " is not supported.");
+  }
+  return Status::OK();
+}
+
+typedef Eigen::Tensor<uint8, 2, Eigen::RowMajor> Uint8Image;
+
+// Add the sequence of images specified by ith_image to the summary.
+//
+// Factoring this loop out into a helper function lets ith_image behave
+// differently in the float and uint8 cases: the float case needs a temporary
+// buffer which can be shared across calls to ith_image, but the uint8 case
+// does not.
+Status AddImages(const string& tag, int max_images, int batch_size, int w,
+                 int h, int depth,
+                 const std::function<Uint8Image(int)>& ith_image, Summary* s) {
+  const int N = std::min<int>(max_images, batch_size);
+  for (int i = 0; i < N; ++i) {
+    Summary::Value* v = s->add_value();
+    // The tag depends on the number of requested images (not the number
+    // produced.)
+    //
+    // Note that later on avisu uses "/" to figure out a consistent naming
+    // convention for display, so we append "/image" to guarantee that the
+    // image(s) won't be displayed in the global scope with no name.
+    if (max_images > 1) {
+      v->set_tag(strings::StrCat(tag, "/image/", i));
+    } else {
+      v->set_tag(strings::StrCat(tag, "/image"));
+    }
+
+    const auto image = ith_image(i);
+    Summary::Image* si = v->mutable_image();
+    si->set_height(h);
+    si->set_width(w);
+    si->set_colorspace(depth);
+    const int channel_bits = 8;
+    const int compression = -1;  // Use zlib default
+    if (!png::WriteImageToBuffer(image.data(), w, h, w * depth, depth,
+                                 channel_bits, compression,
+                                 si->mutable_encoded_image_string(), nullptr)) {
+      return errors::Internal("PNG encoding failed");
+    }
+  }
+  return Status::OK();
+}
+
+template <class T>
+void NormalizeFloatImage(int hw, int depth,
+                         typename TTypes<T>::ConstMatrix values,
+                         typename TTypes<uint8>::ConstVec bad_color,
+                         Uint8Image* image) {
+  if (!image->size()) return;  // Nothing to do for empty images
+
+  // Rescale the image to uint8 range.
+  //
+  // We are trying to generate an RGB image from a float/half tensor.  We do
+  // not have any info about the expected range of values in the tensor
+  // but the generated image needs to have all RGB values within [0, 255].
+  //
+  // We use two different algorithms to generate these values.  If the
+  // tensor has only positive values we scale them all by 255/max(values).
+  // If the tensor has both negative and positive values we scale them by
+  // the max of their absolute values and center them around 127.
+  //
+  // This works for most cases, but does not respect the relative dynamic
+  // range across different instances of the tensor.
+
+  // Compute min and max ignoring nonfinite pixels
+  float image_min = std::numeric_limits<float>::infinity();
+  float image_max = -image_min;
+  for (int i = 0; i < hw; i++) {
+    bool finite = true;
+    for (int j = 0; j < depth; j++) {
+      if (!Eigen::numext::isfinite(values(i, j))) {
+        finite = false;
+        break;
+      }
+    }
+    if (finite) {
+      for (int j = 0; j < depth; j++) {
+        float value(values(i, j));
+        image_min = std::min(image_min, value);
+        image_max = std::max(image_max, value);
+      }
+    }
+  }
+
+  // Pick an affine transform into uint8
+  const float kZeroThreshold = 1e-6;
+  T scale, offset;
+  if (image_min < 0) {
+    const float max_val = std::max(std::abs(image_min), std::abs(image_max));
+    scale = T(max_val < kZeroThreshold ? 0.0f : 127.0f / max_val);
+    offset = T(128.0f);
+  } else {
+    scale = T(image_max < kZeroThreshold ? 0.0f : 255.0f / image_max);
+    offset = T(0.0f);
+  }
+
+  // Transform image, turning nonfinite values to bad_color
+  for (int i = 0; i < hw; i++) {
+    bool finite = true;
+    for (int j = 0; j < depth; j++) {
+      if (!Eigen::numext::isfinite(values(i, j))) {
+        finite = false;
+        break;
+      }
+    }
+    if (finite) {
+      image->chip<0>(i) =
+          (values.template chip<0>(i) * scale + offset).template cast<uint8>();
+    } else {
+      image->chip<0>(i) = bad_color;
+    }
+  }
+}
+
+template <class T>
+Status NormalizeAndAddImages(const Tensor& tensor, int max_images, int h, int w,
+                             int hw, int depth, int batch_size,
+                             const string& base_tag, Tensor bad_color_tensor,
+                             Summary* s) {
+  // For float and half images, nans and infs are replaced with bad_color.
+  if (bad_color_tensor.dim_size(0) < depth) {
+    return errors::InvalidArgument(
+        "expected depth <= bad_color.size, got depth = ", depth,
+        ", bad_color.size = ", bad_color_tensor.dim_size(0));
+  }
+  auto bad_color_full = bad_color_tensor.vec<uint8>();
+  typename TTypes<uint8>::ConstVec bad_color(bad_color_full.data(), depth);
+
+  // Float images must be scaled and translated.
+  Uint8Image image(hw, depth);
+  auto ith_image = [&tensor, &image, bad_color, batch_size, hw, depth](int i) {
+    auto tensor_eigen = tensor.template shaped<T, 3>({batch_size, hw, depth});
+    typename TTypes<T>::ConstMatrix values(
+        &tensor_eigen(i, 0, 0), Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
+    NormalizeFloatImage<T>(hw, depth, values, bad_color, &image);
+    return image;
+  };
+  return AddImages(base_tag, max_images, batch_size, w, h, depth, ith_image, s);
+}
+
+}  // namespace
+
+class SummaryFileWriter : public SummaryWriterInterface {
+ public:
+  SummaryFileWriter(int max_queue, int flush_millis, Env* env)
+      : SummaryWriterInterface(),
+        is_initialized_(false),
+        max_queue_(max_queue),
+        flush_millis_(flush_millis),
+        env_(env) {}
+
+  Status Initialize(const string& logdir, const string& filename_suffix) {
+    const Status is_dir = env_->IsDirectory(logdir);
+    if (!is_dir.ok()) {
+      if (is_dir.code() != tensorflow::error::NOT_FOUND) {
+        return is_dir;
+      }
+      TF_RETURN_IF_ERROR(env_->CreateDir(logdir));
+    }
+    mutex_lock ml(mu_);
+    events_writer_ =
+        tensorflow::MakeUnique<EventsWriter>(io::JoinPath(logdir, "events"));
+    if (!events_writer_->InitWithSuffix(filename_suffix)) {
+      return errors::Unknown("Could not initialize events writer.");
+    }
+    last_flush_ = env_->NowMicros();
+    is_initialized_ = true;
+    return Status::OK();
+  }
+
+  Status Flush() override {
+    mutex_lock ml(mu_);
+    if (!is_initialized_) {
+      return errors::FailedPrecondition("Class was not properly initialized.");
+    }
+    return InternalFlush();
+  }
+
+  ~SummaryFileWriter() override {
+    (void)Flush();  // Ignore errors.
+  }
+
+  Status WriteTensor(int64 global_step, Tensor t, const string& tag,
+                     const string& serialized_metadata) override {
+    std::unique_ptr<Event> e{new Event};
+    e->set_step(global_step);
+    e->set_wall_time(GetWallTime());
+    Summary::Value* v = e->mutable_summary()->add_value();
+    t.AsProtoTensorContent(v->mutable_tensor());
+    v->set_tag(tag);
+    if (!serialized_metadata.empty()) {
+      v->mutable_metadata()->ParseFromString(serialized_metadata);
+    }
+    return WriteEvent(std::move(e));
+  }
+
+  Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
+    std::unique_ptr<Event> e{new Event};
+    e->set_step(global_step);
+    e->set_wall_time(GetWallTime());
+    Summary::Value* v = e->mutable_summary()->add_value();
+    v->set_tag(tag);
+    float value;
+    TF_RETURN_IF_ERROR(TensorValueAt<float>(t, 0, &value));
+    v->set_simple_value(value);
+    return WriteEvent(std::move(e));
+  }
+
+  Status WriteHistogram(int64 global_step, Tensor t,
+                        const string& tag) override {
+    std::unique_ptr<Event> e{new Event};
+    e->set_step(global_step);
+    e->set_wall_time(GetWallTime());
+    Summary::Value* v = e->mutable_summary()->add_value();
+    v->set_tag(tag);
+    histogram::Histogram histo;
+    for (int64 i = 0; i < t.NumElements(); i++) {
+      double double_val;
+      TF_RETURN_IF_ERROR(TensorValueAt<double>(t, i, &double_val));
+      if (Eigen::numext::isnan(double_val)) {
+        return errors::InvalidArgument("Nan in summary histogram for: ", tag);
+      } else if (Eigen::numext::isinf(double_val)) {
+        return errors::InvalidArgument("Infinity in summary histogram for: ",
+                                       tag);
+      }
+      histo.Add(double_val);
+    }
+
+    histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
+    return WriteEvent(std::move(e));
+  }
+
+  Status WriteImage(int64 global_step, Tensor tensor, const string& tag,
+                    int max_images, Tensor bad_color) override {
+    if (!(tensor.dims() == 4 &&
+          (tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 ||
+           tensor.dim_size(3) == 4))) {
+      return errors::InvalidArgument(
+          "Tensor must be 4-D with last dim 1, 3, or 4, not ",
+          tensor.shape().DebugString());
+    }
+    if (!(tensor.dim_size(0) < (1LL << 31) &&
+          tensor.dim_size(1) < (1LL << 31) &&
+          tensor.dim_size(2) < (1LL << 31) &&
+          (tensor.dim_size(1) * tensor.dim_size(2)) < (1LL << 29))) {
+      return errors::InvalidArgument("Tensor too large for summary ",
+                                     tensor.shape().DebugString());
+    }
+    std::unique_ptr<Event> e{new Event};
+    e->set_step(global_step);
+    e->set_wall_time(GetWallTime());
+    Summary* s = e->mutable_summary();
+    // The casts and h * w cannot overflow because of the limits above.
+    const int batch_size = static_cast<int>(tensor.dim_size(0));
+    const int h = static_cast<int>(tensor.dim_size(1));
+    const int w = static_cast<int>(tensor.dim_size(2));
+    const int hw = h * w;  // Compact these two dims for simplicity
+    const int depth = static_cast<int>(tensor.dim_size(3));
+    if (tensor.dtype() == DT_UINT8) {
+      // For uint8 input, no normalization is necessary
+      auto ith_image = [&tensor, batch_size, hw, depth](int i) {
+        auto values = tensor.shaped<uint8, 3>({batch_size, hw, depth});
+        return typename TTypes<uint8>::ConstMatrix(
+            &values(i, 0, 0), Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
+      };
+      TF_RETURN_IF_ERROR(
+          AddImages(tag, max_images, batch_size, w, h, depth, ith_image, s));
+    } else if (tensor.dtype() == DT_HALF) {
+      TF_RETURN_IF_ERROR(NormalizeAndAddImages<Eigen::half>(
+          tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, s));
+    } else if (tensor.dtype() == DT_FLOAT) {
+      TF_RETURN_IF_ERROR(NormalizeAndAddImages<float>(
+          tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, s));
+    } else {
+      return errors::InvalidArgument(
+          "Only DT_INT8, DT_HALF, and DT_FLOAT images are supported. Got ",
+          DataTypeString(tensor.dtype()));
+    }
+
+    return WriteEvent(std::move(e));
+  }
+
+  Status WriteAudio(int64 global_step, Tensor tensor, const string& tag,
+                    int max_outputs, float sample_rate) override {
+    if (sample_rate <= 0.0f) {
+      return errors::InvalidArgument("sample_rate must be > 0");
+    }
+    const int batch_size = tensor.dim_size(0);
+    const int64 length_frames = tensor.dim_size(1);
+    const int64 num_channels =
+        tensor.dims() == 2 ? 1 : tensor.dim_size(tensor.dims() - 1);
+    std::unique_ptr<Event> e{new Event};
+    e->set_step(global_step);
+    e->set_wall_time(GetWallTime());
+    Summary* s = e->mutable_summary();
+    const int N = std::min<int>(max_outputs, batch_size);
+    for (int i = 0; i < N; ++i) {
+      Summary::Value* v = s->add_value();
+      if (max_outputs > 1) {
+        v->set_tag(strings::StrCat(tag, "/audio/", i));
+      } else {
+        v->set_tag(strings::StrCat(tag, "/audio"));
+      }
+
+      Summary::Audio* sa = v->mutable_audio();
+      sa->set_sample_rate(sample_rate);
+      sa->set_num_channels(num_channels);
+      sa->set_length_frames(length_frames);
+      sa->set_content_type("audio/wav");
+
+      auto values =
+          tensor.shaped<float, 3>({batch_size, length_frames, num_channels});
+      auto channels_by_frames = typename TTypes<float>::ConstMatrix(
+          &values(i, 0, 0),
+          Eigen::DSizes<Eigen::DenseIndex, 2>(length_frames, num_channels));
+      size_t sample_rate_truncated = lrintf(sample_rate);
+      if (sample_rate_truncated == 0) {
+        sample_rate_truncated = 1;
+      }
+      TF_RETURN_IF_ERROR(wav::EncodeAudioAsS16LEWav(
+          channels_by_frames.data(), sample_rate_truncated, num_channels,
+          length_frames, sa->mutable_encoded_audio_string()));
+    }
+    return WriteEvent(std::move(e));
+  }
+
+  Status WriteGraph(int64 global_step,
+                    std::unique_ptr<GraphDef> graph) override {
+    std::unique_ptr<Event> e{new Event};
+    e->set_step(global_step);
+    e->set_wall_time(GetWallTime());
+    graph->SerializeToString(e->mutable_graph_def());
+    return WriteEvent(std::move(e));
+  }
+
+  Status WriteEvent(std::unique_ptr<Event> event) override {
+    mutex_lock ml(mu_);
+    queue_.emplace_back(std::move(event));
+    if (queue_.size() >= max_queue_ ||
+        env_->NowMicros() - last_flush_ > 1000 * flush_millis_) {
+      return InternalFlush();
+    }
+    return Status::OK();
+  }
+
+  string DebugString() override { return "SummaryFileWriter"; }
+
+ private:
+  double GetWallTime() {
+    return static_cast<double>(env_->NowMicros()) / 1.0e6;
+  }
+
+  Status InternalFlush() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    for (const std::unique_ptr<Event>& e : queue_) {
+      events_writer_->WriteEvent(*e);
+    }
+    queue_.clear();
+    if (!events_writer_->Flush()) {
+      return errors::InvalidArgument("Could not flush events file.");
+    }
+    last_flush_ = env_->NowMicros();
+    return Status::OK();
+  }
+
+  bool is_initialized_;
+  const int max_queue_;
+  const int flush_millis_;
+  uint64 last_flush_;
+  Env* env_;
+  mutex mu_;
+  std::vector<std::unique_ptr<Event>> queue_ GUARDED_BY(mu_);
+  // A pointer to allow deferred construction.
+  std::unique_ptr<EventsWriter> events_writer_ GUARDED_BY(mu_);
+  std::vector<std::pair<string, SummaryMetadata>> registered_summaries_
+      GUARDED_BY(mu_);
+};
+
+Status CreateSummaryFileWriter(int max_queue, int flush_millis,
+                               const string& logdir,
+                               const string& filename_suffix, Env* env,
+                               SummaryWriterInterface** result) {
+  SummaryFileWriter* w = new SummaryFileWriter(max_queue, flush_millis, env);
+  const Status s = w->Initialize(logdir, filename_suffix);
+  if (!s.ok()) {
+    w->Unref();
+    *result = nullptr;
+    return s;
+  }
+  *result = w;
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/tensorboard/db/summary_file_writer.h b/tensorflow/contrib/tensorboard/db/summary_file_writer.h
new file mode 100644 (file)
index 0000000..73b0a55
--- /dev/null
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_TENSORBOARD_DB_SUMMARY_FILE_WRITER_H_
+#define TENSORFLOW_CONTRIB_TENSORBOARD_DB_SUMMARY_FILE_WRITER_H_
+
+#include "tensorflow/core/kernels/summary_interface.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+/// \brief Creates SummaryWriterInterface which writes to a file.
+///
+/// The file is an append-only records file of tf.Event protos. That
+/// makes this summary writer suitable for file systems like GCS.
+///
+/// It will enqueue up to max_queue summaries, and flush at least every
+/// flush_millis milliseconds. The summaries will be written to the
+/// directory specified by logdir and with the filename suffixed by
+/// filename_suffix. The caller owns a reference to result if the
+/// returned status is ok. The Env object must not be destroyed until
+/// after the returned writer.
+Status CreateSummaryFileWriter(int max_queue, int flush_millis,
+                               const string& logdir,
+                               const string& filename_suffix, Env* env,
+                               SummaryWriterInterface** result);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CONTRIB_TENSORBOARD_DB_SUMMARY_FILE_WRITER_H_
diff --git a/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc
new file mode 100644 (file)
index 0000000..c61b465
--- /dev/null
@@ -0,0 +1,216 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/tensorboard/db/summary_file_writer.h"
+
+#include "tensorflow/core/framework/summary.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/io/record_reader.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/event.pb.h"
+
+namespace tensorflow {
+namespace {
+
+class FakeClockEnv : public EnvWrapper {
+ public:
+  FakeClockEnv() : EnvWrapper(Env::Default()), current_millis_(0) {}
+  void AdvanceByMillis(const uint64 millis) { current_millis_ += millis; }
+  uint64 NowMicros() override { return current_millis_ * 1000; }
+  uint64 NowSeconds() override { return current_millis_ * 1000; }
+
+ private:
+  uint64 current_millis_;
+};
+
+class SummaryFileWriterTest : public ::testing::Test {
+ protected:
+  Status SummaryTestHelper(
+      const string& test_name,
+      const std::function<Status(SummaryWriterInterface*)>& writer_fn,
+      const std::function<void(const Event&)>& test_fn) {
+    static std::set<string>* tests = new std::set<string>();
+    CHECK(tests->insert(test_name).second) << ": " << test_name;
+
+    SummaryWriterInterface* writer;
+    TF_CHECK_OK(CreateSummaryFileWriter(1, 1, testing::TmpDir(), test_name,
+                                        &env_, &writer));
+    core::ScopedUnref deleter(writer);
+
+    TF_CHECK_OK(writer_fn(writer));
+    TF_CHECK_OK(writer->Flush());
+
+    std::vector<string> files;
+    TF_CHECK_OK(env_.GetChildren(testing::TmpDir(), &files));
+    bool found = false;
+    for (const string& f : files) {
+      if (StringPiece(f).contains(test_name)) {
+        if (found) {
+          return errors::Unknown("Found more than one file for ", test_name);
+        }
+        found = true;
+        std::unique_ptr<RandomAccessFile> read_file;
+        TF_CHECK_OK(env_.NewRandomAccessFile(io::JoinPath(testing::TmpDir(), f),
+                                             &read_file));
+        io::RecordReader reader(read_file.get(), io::RecordReaderOptions());
+        string record;
+        uint64 offset = 0;
+        TF_CHECK_OK(
+            reader.ReadRecord(&offset,
+                              &record));  // The first event is irrelevant
+        TF_CHECK_OK(reader.ReadRecord(&offset, &record));
+        Event e;
+        e.ParseFromString(record);
+        test_fn(e);
+      }
+    }
+    if (!found) {
+      return errors::Unknown("Found no file for ", test_name);
+    }
+    return Status::OK();
+  }
+
+  FakeClockEnv env_;
+};
+
+TEST_F(SummaryFileWriterTest, WriteTensor) {
+  TF_CHECK_OK(SummaryTestHelper("tensor_test",
+                                [](SummaryWriterInterface* writer) {
+                                  Tensor one(DT_FLOAT, TensorShape({}));
+                                  one.scalar<float>()() = 1.0;
+                                  TF_RETURN_IF_ERROR(writer->WriteTensor(
+                                      2, one, "name",
+                                      SummaryMetadata().SerializeAsString()));
+                                  TF_RETURN_IF_ERROR(writer->Flush());
+                                  return Status::OK();
+                                },
+                                [](const Event& e) {
+                                  EXPECT_EQ(e.step(), 2);
+                                  CHECK_EQ(e.summary().value_size(), 1);
+                                  EXPECT_EQ(e.summary().value(0).tag(), "name");
+                                }));
+}
+
+TEST_F(SummaryFileWriterTest, WriteScalar) {
+  TF_CHECK_OK(SummaryTestHelper(
+      "scalar_test",
+      [](SummaryWriterInterface* writer) {
+        Tensor one(DT_FLOAT, TensorShape({}));
+        one.scalar<float>()() = 1.0;
+        TF_RETURN_IF_ERROR(writer->WriteScalar(2, one, "name"));
+        TF_RETURN_IF_ERROR(writer->Flush());
+        return Status::OK();
+      },
+      [](const Event& e) {
+        EXPECT_EQ(e.step(), 2);
+        CHECK_EQ(e.summary().value_size(), 1);
+        EXPECT_EQ(e.summary().value(0).tag(), "name");
+        EXPECT_EQ(e.summary().value(0).simple_value(), 1.0);
+      }));
+}
+
+TEST_F(SummaryFileWriterTest, WriteHistogram) {
+  TF_CHECK_OK(SummaryTestHelper("hist_test",
+                                [](SummaryWriterInterface* writer) {
+                                  Tensor one(DT_FLOAT, TensorShape({}));
+                                  one.scalar<float>()() = 1.0;
+                                  TF_RETURN_IF_ERROR(
+                                      writer->WriteHistogram(2, one, "name"));
+                                  TF_RETURN_IF_ERROR(writer->Flush());
+                                  return Status::OK();
+                                },
+                                [](const Event& e) {
+                                  EXPECT_EQ(e.step(), 2);
+                                  CHECK_EQ(e.summary().value_size(), 1);
+                                  EXPECT_EQ(e.summary().value(0).tag(), "name");
+                                  EXPECT_TRUE(e.summary().value(0).has_histo());
+                                }));
+}
+
+TEST_F(SummaryFileWriterTest, WriteImage) {
+  TF_CHECK_OK(SummaryTestHelper(
+      "image_test",
+      [](SummaryWriterInterface* writer) {
+        Tensor one(DT_UINT8, TensorShape({1, 1, 1, 1}));
+        one.scalar<int8>()() = 1;
+        TF_RETURN_IF_ERROR(writer->WriteImage(2, one, "name", 1, Tensor()));
+        TF_RETURN_IF_ERROR(writer->Flush());
+        return Status::OK();
+      },
+      [](const Event& e) {
+        EXPECT_EQ(e.step(), 2);
+        CHECK_EQ(e.summary().value_size(), 1);
+        EXPECT_EQ(e.summary().value(0).tag(), "name/image");
+        CHECK(e.summary().value(0).has_image());
+        EXPECT_EQ(e.summary().value(0).image().height(), 1);
+        EXPECT_EQ(e.summary().value(0).image().width(), 1);
+        EXPECT_EQ(e.summary().value(0).image().colorspace(), 1);
+      }));
+}
+
+TEST_F(SummaryFileWriterTest, WriteAudio) {
+  TF_CHECK_OK(SummaryTestHelper(
+      "audio_test",
+      [](SummaryWriterInterface* writer) {
+        Tensor one(DT_FLOAT, TensorShape({1, 1}));
+        one.scalar<float>()() = 1.0;
+        TF_RETURN_IF_ERROR(writer->WriteAudio(2, one, "name", 1, 1));
+        TF_RETURN_IF_ERROR(writer->Flush());
+        return Status::OK();
+      },
+      [](const Event& e) {
+        EXPECT_EQ(e.step(), 2);
+        CHECK_EQ(e.summary().value_size(), 1);
+        EXPECT_EQ(e.summary().value(0).tag(), "name/audio");
+        CHECK(e.summary().value(0).has_audio());
+      }));
+}
+
+TEST_F(SummaryFileWriterTest, WriteEvent) {
+  TF_CHECK_OK(
+      SummaryTestHelper("event_test",
+                        [](SummaryWriterInterface* writer) {
+                          std::unique_ptr<Event> e{new Event};
+                          e->set_step(7);
+                          e->mutable_summary()->add_value()->set_tag("hi");
+                          TF_RETURN_IF_ERROR(writer->WriteEvent(std::move(e)));
+                          TF_RETURN_IF_ERROR(writer->Flush());
+                          return Status::OK();
+                        },
+                        [](const Event& e) {
+                          EXPECT_EQ(e.step(), 7);
+                          CHECK_EQ(e.summary().value_size(), 1);
+                          EXPECT_EQ(e.summary().value(0).tag(), "hi");
+                        }));
+}
+
+TEST_F(SummaryFileWriterTest, WallTime) {
+  env_.AdvanceByMillis(7023);
+  TF_CHECK_OK(SummaryTestHelper(
+      "wall_time_test",
+      [](SummaryWriterInterface* writer) {
+        Tensor one(DT_FLOAT, TensorShape({}));
+        one.scalar<float>()() = 1.0;
+        TF_RETURN_IF_ERROR(writer->WriteScalar(2, one, "name"));
+        TF_RETURN_IF_ERROR(writer->Flush());
+        return Status::OK();
+      },
+      [](const Event& e) { EXPECT_EQ(e.wall_time(), 7.023); }));
+}
+
+}  // namespace
+}  // namespace tensorflow
index 878faf261bff022e86d693a848e1e314c4bd8b3a..1f3898de14fe4040e58d93d3f7ca40bd3318df0e 100644 (file)
@@ -5901,27 +5901,11 @@ tf_kernel_library(
 
 cc_library(
     name = "summary_interface",
-    srcs = ["summary_interface.cc"],
     hdrs = ["summary_interface.h"],
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:ptr_util",
-    ],
-)
-
-tf_cc_test(
-    name = "summary_interface_test",
-    srcs = ["summary_interface_test.cc"],
-    deps = [
-        ":summary_interface",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
     ],
 )
 
@@ -5929,9 +5913,9 @@ tf_kernel_library(
     name = "summary_kernels",
     srcs = ["summary_kernels.cc"],
     deps = [
-        ":summary_interface",
         "//tensorflow/contrib/tensorboard/db:schema",
         "//tensorflow/contrib/tensorboard/db:summary_db_writer",
+        "//tensorflow/contrib/tensorboard/db:summary_file_writer",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/core/kernels/summary_interface.cc b/tensorflow/core/kernels/summary_interface.cc
deleted file mode 100644 (file)
index 97c0c2c..0000000
+++ /dev/null
@@ -1,462 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/kernels/summary_interface.h"
-
-#include <utility>
-
-#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/framework/summary.pb.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/lib/histogram/histogram.h"
-#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/png/png_io.h"
-#include "tensorflow/core/lib/wav/wav_io.h"
-#include "tensorflow/core/util/events_writer.h"
-#include "tensorflow/core/util/ptr_util.h"
-
-namespace tensorflow {
-namespace {
-template <typename T>
-Status TensorValueAt(Tensor t, int64 index, T* out) {
-  switch (t.dtype()) {
-    case DT_FLOAT:
-      *out = t.flat<float>()(index);
-      break;
-    case DT_DOUBLE:
-      *out = t.flat<double>()(index);
-      break;
-    case DT_HALF:
-      *out = T(t.flat<Eigen::half>()(index));
-      break;
-    case DT_INT32:
-      *out = t.flat<int32>()(index);
-      break;
-    case DT_UINT8:
-      *out = t.flat<uint8>()(index);
-      break;
-    case DT_INT16:
-      *out = t.flat<int16>()(index);
-      break;
-    case DT_INT8:
-      *out = t.flat<int8>()(index);
-      break;
-    case DT_BOOL:
-      *out = t.flat<bool>()(index);
-      break;
-    case DT_INT64:
-      *out = t.flat<int64>()(index);
-      break;
-    default:
-      return errors::Unimplemented("Scalar summary for dtype ",
-                                   DataTypeString(t.dtype()),
-                                   " is not supported.");
-  }
-  return Status::OK();
-}
-
-typedef Eigen::Tensor<uint8, 2, Eigen::RowMajor> Uint8Image;
-
-// Add the sequence of images specified by ith_image to the summary.
-//
-// Factoring this loop out into a helper function lets ith_image behave
-// differently in the float and uint8 cases: the float case needs a temporary
-// buffer which can be shared across calls to ith_image, but the uint8 case
-// does not.
-Status AddImages(const string& tag, int max_images, int batch_size, int w,
-                 int h, int depth,
-                 const std::function<Uint8Image(int)>& ith_image, Summary* s) {
-  const int N = std::min<int>(max_images, batch_size);
-  for (int i = 0; i < N; ++i) {
-    Summary::Value* v = s->add_value();
-    // The tag depends on the number of requested images (not the number
-    // produced.)
-    //
-    // Note that later on avisu uses "/" to figure out a consistent naming
-    // convention for display, so we append "/image" to guarantee that the
-    // image(s) won't be displayed in the global scope with no name.
-    if (max_images > 1) {
-      v->set_tag(strings::StrCat(tag, "/image/", i));
-    } else {
-      v->set_tag(strings::StrCat(tag, "/image"));
-    }
-
-    const auto image = ith_image(i);
-    Summary::Image* si = v->mutable_image();
-    si->set_height(h);
-    si->set_width(w);
-    si->set_colorspace(depth);
-    const int channel_bits = 8;
-    const int compression = -1;  // Use zlib default
-    if (!png::WriteImageToBuffer(image.data(), w, h, w * depth, depth,
-                                 channel_bits, compression,
-                                 si->mutable_encoded_image_string(), nullptr)) {
-      return errors::Internal("PNG encoding failed");
-    }
-  }
-  return Status::OK();
-}
-
-template <class T>
-void NormalizeFloatImage(int hw, int depth,
-                         typename TTypes<T>::ConstMatrix values,
-                         typename TTypes<uint8>::ConstVec bad_color,
-                         Uint8Image* image) {
-  if (!image->size()) return;  // Nothing to do for empty images
-
-  // Rescale the image to uint8 range.
-  //
-  // We are trying to generate an RGB image from a float/half tensor.  We do
-  // not have any info about the expected range of values in the tensor
-  // but the generated image needs to have all RGB values within [0, 255].
-  //
-  // We use two different algorithms to generate these values.  If the
-  // tensor has only positive values we scale them all by 255/max(values).
-  // If the tensor has both negative and positive values we scale them by
-  // the max of their absolute values and center them around 127.
-  //
-  // This works for most cases, but does not respect the relative dynamic
-  // range across different instances of the tensor.
-
-  // Compute min and max ignoring nonfinite pixels
-  float image_min = std::numeric_limits<float>::infinity();
-  float image_max = -image_min;
-  for (int i = 0; i < hw; i++) {
-    bool finite = true;
-    for (int j = 0; j < depth; j++) {
-      if (!Eigen::numext::isfinite(values(i, j))) {
-        finite = false;
-        break;
-      }
-    }
-    if (finite) {
-      for (int j = 0; j < depth; j++) {
-        float value(values(i, j));
-        image_min = std::min(image_min, value);
-        image_max = std::max(image_max, value);
-      }
-    }
-  }
-
-  // Pick an affine transform into uint8
-  const float kZeroThreshold = 1e-6;
-  T scale, offset;
-  if (image_min < 0) {
-    const float max_val = std::max(std::abs(image_min), std::abs(image_max));
-    scale = T(max_val < kZeroThreshold ? 0.0f : 127.0f / max_val);
-    offset = T(128.0f);
-  } else {
-    scale = T(image_max < kZeroThreshold ? 0.0f : 255.0f / image_max);
-    offset = T(0.0f);
-  }
-
-  // Transform image, turning nonfinite values to bad_color
-  for (int i = 0; i < hw; i++) {
-    bool finite = true;
-    for (int j = 0; j < depth; j++) {
-      if (!Eigen::numext::isfinite(values(i, j))) {
-        finite = false;
-        break;
-      }
-    }
-    if (finite) {
-      image->chip<0>(i) =
-          (values.template chip<0>(i) * scale + offset).template cast<uint8>();
-    } else {
-      image->chip<0>(i) = bad_color;
-    }
-  }
-}
-
-template <class T>
-Status NormalizeAndAddImages(const Tensor& tensor, int max_images, int h, int w,
-                             int hw, int depth, int batch_size,
-                             const string& base_tag, Tensor bad_color_tensor,
-                             Summary* s) {
-  // For float and half images, nans and infs are replaced with bad_color.
-  if (bad_color_tensor.dim_size(0) < depth) {
-    return errors::InvalidArgument(
-        "expected depth <= bad_color.size, got depth = ", depth,
-        ", bad_color.size = ", bad_color_tensor.dim_size(0));
-  }
-  auto bad_color_full = bad_color_tensor.vec<uint8>();
-  typename TTypes<uint8>::ConstVec bad_color(bad_color_full.data(), depth);
-
-  // Float images must be scaled and translated.
-  Uint8Image image(hw, depth);
-  auto ith_image = [&tensor, &image, bad_color, batch_size, hw, depth](int i) {
-    auto tensor_eigen = tensor.template shaped<T, 3>({batch_size, hw, depth});
-    typename TTypes<T>::ConstMatrix values(
-        &tensor_eigen(i, 0, 0), Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
-    NormalizeFloatImage<T>(hw, depth, values, bad_color, &image);
-    return image;
-  };
-  return AddImages(base_tag, max_images, batch_size, w, h, depth, ith_image, s);
-}
-
-}  // namespace
-
-class SummaryWriterImpl : public SummaryWriterInterface {
- public:
-  SummaryWriterImpl(int max_queue, int flush_millis, Env* env)
-      : SummaryWriterInterface(),
-        is_initialized_(false),
-        max_queue_(max_queue),
-        flush_millis_(flush_millis),
-        env_(env) {}
-
-  Status Initialize(const string& logdir, const string& filename_suffix) {
-    const Status is_dir = env_->IsDirectory(logdir);
-    if (!is_dir.ok()) {
-      if (is_dir.code() != tensorflow::error::NOT_FOUND) {
-        return is_dir;
-      }
-      TF_RETURN_IF_ERROR(env_->CreateDir(logdir));
-    }
-    mutex_lock ml(mu_);
-    events_writer_ =
-        tensorflow::MakeUnique<EventsWriter>(io::JoinPath(logdir, "events"));
-    if (!events_writer_->InitWithSuffix(filename_suffix)) {
-      return errors::Unknown("Could not initialize events writer.");
-    }
-    last_flush_ = env_->NowMicros();
-    is_initialized_ = true;
-    return Status::OK();
-  }
-
-  Status Flush() override {
-    mutex_lock ml(mu_);
-    if (!is_initialized_) {
-      return errors::FailedPrecondition("Class was not properly initialized.");
-    }
-    return InternalFlush();
-  }
-
-  ~SummaryWriterImpl() override {
-    (void)Flush();  // Ignore errors.
-  }
-
-  Status WriteTensor(int64 global_step, Tensor t, const string& tag,
-                     const string& serialized_metadata) override {
-    std::unique_ptr<Event> e{new Event};
-    e->set_step(global_step);
-    e->set_wall_time(GetWallTime());
-    Summary::Value* v = e->mutable_summary()->add_value();
-    t.AsProtoTensorContent(v->mutable_tensor());
-    v->set_tag(tag);
-    if (!serialized_metadata.empty()) {
-      v->mutable_metadata()->ParseFromString(serialized_metadata);
-    }
-    return WriteEvent(std::move(e));
-  }
-
-  Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
-    std::unique_ptr<Event> e{new Event};
-    e->set_step(global_step);
-    e->set_wall_time(GetWallTime());
-    Summary::Value* v = e->mutable_summary()->add_value();
-    v->set_tag(tag);
-    float value;
-    TF_RETURN_IF_ERROR(TensorValueAt<float>(t, 0, &value));
-    v->set_simple_value(value);
-    return WriteEvent(std::move(e));
-  }
-
-  Status WriteHistogram(int64 global_step, Tensor t,
-                        const string& tag) override {
-    std::unique_ptr<Event> e{new Event};
-    e->set_step(global_step);
-    e->set_wall_time(GetWallTime());
-    Summary::Value* v = e->mutable_summary()->add_value();
-    v->set_tag(tag);
-    histogram::Histogram histo;
-    for (int64 i = 0; i < t.NumElements(); i++) {
-      double double_val;
-      TF_RETURN_IF_ERROR(TensorValueAt<double>(t, i, &double_val));
-      if (Eigen::numext::isnan(double_val)) {
-        return errors::InvalidArgument("Nan in summary histogram for: ", tag);
-      } else if (Eigen::numext::isinf(double_val)) {
-        return errors::InvalidArgument("Infinity in summary histogram for: ",
-                                       tag);
-      }
-      histo.Add(double_val);
-    }
-
-    histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
-    return WriteEvent(std::move(e));
-  }
-
-  Status WriteImage(int64 global_step, Tensor tensor, const string& tag,
-                    int max_images, Tensor bad_color) override {
-    if (!(tensor.dims() == 4 &&
-          (tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 ||
-           tensor.dim_size(3) == 4))) {
-      return errors::InvalidArgument(
-          "Tensor must be 4-D with last dim 1, 3, or 4, not ",
-          tensor.shape().DebugString());
-    }
-    if (!(tensor.dim_size(0) < (1LL << 31) &&
-          tensor.dim_size(1) < (1LL << 31) &&
-          tensor.dim_size(2) < (1LL << 31) &&
-          (tensor.dim_size(1) * tensor.dim_size(2)) < (1LL << 29))) {
-      return errors::InvalidArgument("Tensor too large for summary ",
-                                     tensor.shape().DebugString());
-    }
-    std::unique_ptr<Event> e{new Event};
-    e->set_step(global_step);
-    e->set_wall_time(GetWallTime());
-    Summary* s = e->mutable_summary();
-    // The casts and h * w cannot overflow because of the limits above.
-    const int batch_size = static_cast<int>(tensor.dim_size(0));
-    const int h = static_cast<int>(tensor.dim_size(1));
-    const int w = static_cast<int>(tensor.dim_size(2));
-    const int hw = h * w;  // Compact these two dims for simplicity
-    const int depth = static_cast<int>(tensor.dim_size(3));
-    if (tensor.dtype() == DT_UINT8) {
-      // For uint8 input, no normalization is necessary
-      auto ith_image = [&tensor, batch_size, hw, depth](int i) {
-        auto values = tensor.shaped<uint8, 3>({batch_size, hw, depth});
-        return typename TTypes<uint8>::ConstMatrix(
-            &values(i, 0, 0), Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
-      };
-      TF_RETURN_IF_ERROR(
-          AddImages(tag, max_images, batch_size, w, h, depth, ith_image, s));
-    } else if (tensor.dtype() == DT_HALF) {
-      TF_RETURN_IF_ERROR(NormalizeAndAddImages<Eigen::half>(
-          tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, s));
-    } else if (tensor.dtype() == DT_FLOAT) {
-      TF_RETURN_IF_ERROR(NormalizeAndAddImages<float>(
-          tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, s));
-    } else {
-      return errors::InvalidArgument(
-          "Only DT_INT8, DT_HALF, and DT_FLOAT images are supported. Got ",
-          DataTypeString(tensor.dtype()));
-    }
-
-    return WriteEvent(std::move(e));
-  }
-
-  Status WriteAudio(int64 global_step, Tensor tensor, const string& tag,
-                    int max_outputs, float sample_rate) override {
-    if (sample_rate <= 0.0f) {
-      return errors::InvalidArgument("sample_rate must be > 0");
-    }
-    const int batch_size = tensor.dim_size(0);
-    const int64 length_frames = tensor.dim_size(1);
-    const int64 num_channels =
-        tensor.dims() == 2 ? 1 : tensor.dim_size(tensor.dims() - 1);
-    std::unique_ptr<Event> e{new Event};
-    e->set_step(global_step);
-    e->set_wall_time(GetWallTime());
-    Summary* s = e->mutable_summary();
-    const int N = std::min<int>(max_outputs, batch_size);
-    for (int i = 0; i < N; ++i) {
-      Summary::Value* v = s->add_value();
-      if (max_outputs > 1) {
-        v->set_tag(strings::StrCat(tag, "/audio/", i));
-      } else {
-        v->set_tag(strings::StrCat(tag, "/audio"));
-      }
-
-      Summary::Audio* sa = v->mutable_audio();
-      sa->set_sample_rate(sample_rate);
-      sa->set_num_channels(num_channels);
-      sa->set_length_frames(length_frames);
-      sa->set_content_type("audio/wav");
-
-      auto values =
-          tensor.shaped<float, 3>({batch_size, length_frames, num_channels});
-      auto channels_by_frames = typename TTypes<float>::ConstMatrix(
-          &values(i, 0, 0),
-          Eigen::DSizes<Eigen::DenseIndex, 2>(length_frames, num_channels));
-      size_t sample_rate_truncated = lrintf(sample_rate);
-      if (sample_rate_truncated == 0) {
-        sample_rate_truncated = 1;
-      }
-      TF_RETURN_IF_ERROR(wav::EncodeAudioAsS16LEWav(
-          channels_by_frames.data(), sample_rate_truncated, num_channels,
-          length_frames, sa->mutable_encoded_audio_string()));
-    }
-    return WriteEvent(std::move(e));
-  }
-
-  Status WriteGraph(int64 global_step,
-                    std::unique_ptr<GraphDef> graph) override {
-    std::unique_ptr<Event> e{new Event};
-    e->set_step(global_step);
-    e->set_wall_time(GetWallTime());
-    graph->SerializeToString(e->mutable_graph_def());
-    return WriteEvent(std::move(e));
-  }
-
-  Status WriteEvent(std::unique_ptr<Event> event) override {
-    mutex_lock ml(mu_);
-    queue_.emplace_back(std::move(event));
-    if (queue_.size() >= max_queue_ ||
-        env_->NowMicros() - last_flush_ > 1000 * flush_millis_) {
-      return InternalFlush();
-    }
-    return Status::OK();
-  }
-
-  string DebugString() override { return "SummaryWriterImpl"; }
-
- private:
-  double GetWallTime() {
-    return static_cast<double>(env_->NowMicros()) / 1.0e6;
-  }
-
-  Status InternalFlush() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-    for (const std::unique_ptr<Event>& e : queue_) {
-      events_writer_->WriteEvent(*e);
-    }
-    queue_.clear();
-    if (!events_writer_->Flush()) {
-      return errors::InvalidArgument("Could not flush events file.");
-    }
-    last_flush_ = env_->NowMicros();
-    return Status::OK();
-  }
-
-  bool is_initialized_;
-  const int max_queue_;
-  const int flush_millis_;
-  uint64 last_flush_;
-  Env* env_;
-  mutex mu_;
-  std::vector<std::unique_ptr<Event>> queue_ GUARDED_BY(mu_);
-  // A pointer to allow deferred construction.
-  std::unique_ptr<EventsWriter> events_writer_ GUARDED_BY(mu_);
-  std::vector<std::pair<string, SummaryMetadata>> registered_summaries_
-      GUARDED_BY(mu_);
-};
-
-Status CreateSummaryWriter(int max_queue, int flush_millis,
-                           const string& logdir, const string& filename_suffix,
-                           Env* env, SummaryWriterInterface** result) {
-  SummaryWriterImpl* w = new SummaryWriterImpl(max_queue, flush_millis, env);
-  const Status s = w->Initialize(logdir, filename_suffix);
-  if (!s.ok()) {
-    w->Unref();
-    *result = nullptr;
-    return s;
-  }
-  *result = w;
-  return Status::OK();
-}
-
-}  // namespace tensorflow
index da1c28709fb35372b1f0b28faba757a23bcd9ac4..02391e967a84b2d2ff015d541969163807b9adc2 100644 (file)
@@ -19,6 +19,8 @@ limitations under the License.
 
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/event.pb.h"
 
 namespace tensorflow {
@@ -53,16 +55,6 @@ class SummaryWriterInterface : public ResourceBase {
   virtual Status WriteEvent(std::unique_ptr<Event> e) = 0;
 };
 
-// Creates a SummaryWriterInterface instance which writes to a file. It will
-// enqueue up to max_queue summaries, and flush at least every flush_millis
-// milliseconds. The summaries will be written to the directory specified by
-// logdir and with the filename suffixed by filename_suffix. The caller owns a
-// reference to result if the returned status is ok. The Env object must not
-// be destroyed until after the returned writer.
-Status CreateSummaryWriter(int max_queue, int flush_millis,
-                           const string& logdir, const string& filename_suffix,
-                           Env* env, SummaryWriterInterface** result);
-
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_SUMMARY_INTERFACE_H_
diff --git a/tensorflow/core/kernels/summary_interface_test.cc b/tensorflow/core/kernels/summary_interface_test.cc
deleted file mode 100644 (file)
index 58e021a..0000000
+++ /dev/null
@@ -1,216 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/kernels/summary_interface.h"
-
-#include "tensorflow/core/framework/summary.pb.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/refcount.h"
-#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/io/record_reader.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/util/event.pb.h"
-
-namespace tensorflow {
-namespace {
-
-class FakeClockEnv : public EnvWrapper {
- public:
-  FakeClockEnv() : EnvWrapper(Env::Default()), current_millis_(0) {}
-  void AdvanceByMillis(const uint64 millis) { current_millis_ += millis; }
-  uint64 NowMicros() override { return current_millis_ * 1000; }
-  uint64 NowSeconds() override { return current_millis_ * 1000; }
-
- private:
-  uint64 current_millis_;
-};
-
-class SummaryInterfaceTest : public ::testing::Test {
- protected:
-  Status SummaryTestHelper(
-      const string& test_name,
-      const std::function<Status(SummaryWriterInterface*)>& writer_fn,
-      const std::function<void(const Event&)>& test_fn) {
-    static std::set<string>* tests = new std::set<string>();
-    CHECK(tests->insert(test_name).second) << ": " << test_name;
-
-    SummaryWriterInterface* writer;
-    TF_CHECK_OK(CreateSummaryWriter(1, 1, testing::TmpDir(), test_name, &env_,
-                                    &writer));
-    core::ScopedUnref deleter(writer);
-
-    TF_CHECK_OK(writer_fn(writer));
-    TF_CHECK_OK(writer->Flush());
-
-    std::vector<string> files;
-    TF_CHECK_OK(env_.GetChildren(testing::TmpDir(), &files));
-    bool found = false;
-    for (const string& f : files) {
-      if (StringPiece(f).contains(test_name)) {
-        if (found) {
-          return errors::Unknown("Found more than one file for ", test_name);
-        }
-        found = true;
-        std::unique_ptr<RandomAccessFile> read_file;
-        TF_CHECK_OK(env_.NewRandomAccessFile(io::JoinPath(testing::TmpDir(), f),
-                                             &read_file));
-        io::RecordReader reader(read_file.get(), io::RecordReaderOptions());
-        string record;
-        uint64 offset = 0;
-        TF_CHECK_OK(
-            reader.ReadRecord(&offset,
-                              &record));  // The first event is irrelevant
-        TF_CHECK_OK(reader.ReadRecord(&offset, &record));
-        Event e;
-        e.ParseFromString(record);
-        test_fn(e);
-      }
-    }
-    if (!found) {
-      return errors::Unknown("Found no file for ", test_name);
-    }
-    return Status::OK();
-  }
-
-  FakeClockEnv env_;
-};
-
-TEST_F(SummaryInterfaceTest, WriteTensor) {
-  TF_CHECK_OK(SummaryTestHelper("tensor_test",
-                                [](SummaryWriterInterface* writer) {
-                                  Tensor one(DT_FLOAT, TensorShape({}));
-                                  one.scalar<float>()() = 1.0;
-                                  TF_RETURN_IF_ERROR(writer->WriteTensor(
-                                      2, one, "name",
-                                      SummaryMetadata().SerializeAsString()));
-                                  TF_RETURN_IF_ERROR(writer->Flush());
-                                  return Status::OK();
-                                },
-                                [](const Event& e) {
-                                  EXPECT_EQ(e.step(), 2);
-                                  CHECK_EQ(e.summary().value_size(), 1);
-                                  EXPECT_EQ(e.summary().value(0).tag(), "name");
-                                }));
-}
-
-TEST_F(SummaryInterfaceTest, WriteScalar) {
-  TF_CHECK_OK(SummaryTestHelper(
-      "scalar_test",
-      [](SummaryWriterInterface* writer) {
-        Tensor one(DT_FLOAT, TensorShape({}));
-        one.scalar<float>()() = 1.0;
-        TF_RETURN_IF_ERROR(writer->WriteScalar(2, one, "name"));
-        TF_RETURN_IF_ERROR(writer->Flush());
-        return Status::OK();
-      },
-      [](const Event& e) {
-        EXPECT_EQ(e.step(), 2);
-        CHECK_EQ(e.summary().value_size(), 1);
-        EXPECT_EQ(e.summary().value(0).tag(), "name");
-        EXPECT_EQ(e.summary().value(0).simple_value(), 1.0);
-      }));
-}
-
-TEST_F(SummaryInterfaceTest, WriteHistogram) {
-  TF_CHECK_OK(SummaryTestHelper("hist_test",
-                                [](SummaryWriterInterface* writer) {
-                                  Tensor one(DT_FLOAT, TensorShape({}));
-                                  one.scalar<float>()() = 1.0;
-                                  TF_RETURN_IF_ERROR(
-                                      writer->WriteHistogram(2, one, "name"));
-                                  TF_RETURN_IF_ERROR(writer->Flush());
-                                  return Status::OK();
-                                },
-                                [](const Event& e) {
-                                  EXPECT_EQ(e.step(), 2);
-                                  CHECK_EQ(e.summary().value_size(), 1);
-                                  EXPECT_EQ(e.summary().value(0).tag(), "name");
-                                  EXPECT_TRUE(e.summary().value(0).has_histo());
-                                }));
-}
-
-TEST_F(SummaryInterfaceTest, WriteImage) {
-  TF_CHECK_OK(SummaryTestHelper(
-      "image_test",
-      [](SummaryWriterInterface* writer) {
-        Tensor one(DT_UINT8, TensorShape({1, 1, 1, 1}));
-        one.scalar<int8>()() = 1;
-        TF_RETURN_IF_ERROR(writer->WriteImage(2, one, "name", 1, Tensor()));
-        TF_RETURN_IF_ERROR(writer->Flush());
-        return Status::OK();
-      },
-      [](const Event& e) {
-        EXPECT_EQ(e.step(), 2);
-        CHECK_EQ(e.summary().value_size(), 1);
-        EXPECT_EQ(e.summary().value(0).tag(), "name/image");
-        CHECK(e.summary().value(0).has_image());
-        EXPECT_EQ(e.summary().value(0).image().height(), 1);
-        EXPECT_EQ(e.summary().value(0).image().width(), 1);
-        EXPECT_EQ(e.summary().value(0).image().colorspace(), 1);
-      }));
-}
-
-TEST_F(SummaryInterfaceTest, WriteAudio) {
-  TF_CHECK_OK(SummaryTestHelper(
-      "audio_test",
-      [](SummaryWriterInterface* writer) {
-        Tensor one(DT_FLOAT, TensorShape({1, 1}));
-        one.scalar<float>()() = 1.0;
-        TF_RETURN_IF_ERROR(writer->WriteAudio(2, one, "name", 1, 1));
-        TF_RETURN_IF_ERROR(writer->Flush());
-        return Status::OK();
-      },
-      [](const Event& e) {
-        EXPECT_EQ(e.step(), 2);
-        CHECK_EQ(e.summary().value_size(), 1);
-        EXPECT_EQ(e.summary().value(0).tag(), "name/audio");
-        CHECK(e.summary().value(0).has_audio());
-      }));
-}
-
-TEST_F(SummaryInterfaceTest, WriteEvent) {
-  TF_CHECK_OK(
-      SummaryTestHelper("event_test",
-                        [](SummaryWriterInterface* writer) {
-                          std::unique_ptr<Event> e{new Event};
-                          e->set_step(7);
-                          e->mutable_summary()->add_value()->set_tag("hi");
-                          TF_RETURN_IF_ERROR(writer->WriteEvent(std::move(e)));
-                          TF_RETURN_IF_ERROR(writer->Flush());
-                          return Status::OK();
-                        },
-                        [](const Event& e) {
-                          EXPECT_EQ(e.step(), 7);
-                          CHECK_EQ(e.summary().value_size(), 1);
-                          EXPECT_EQ(e.summary().value(0).tag(), "hi");
-                        }));
-}
-
-TEST_F(SummaryInterfaceTest, WallTime) {
-  env_.AdvanceByMillis(7023);
-  TF_CHECK_OK(SummaryTestHelper(
-      "wall_time_test",
-      [](SummaryWriterInterface* writer) {
-        Tensor one(DT_FLOAT, TensorShape({}));
-        one.scalar<float>()() = 1.0;
-        TF_RETURN_IF_ERROR(writer->WriteScalar(2, one, "name"));
-        TF_RETURN_IF_ERROR(writer->Flush());
-        return Status::OK();
-      },
-      [](const Event& e) { EXPECT_EQ(e.wall_time(), 7.023); }));
-}
-
-}  // namespace
-}  // namespace tensorflow
index a815f540b10a3d6bc4cc98f39d72796c85734d84..41cbece1d648f3e2dba112375e494d2ed8192db9 100644 (file)
@@ -15,10 +15,10 @@ limitations under the License.
 
 #include "tensorflow/contrib/tensorboard/db/schema.h"
 #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
+#include "tensorflow/contrib/tensorboard/db/summary_file_writer.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/kernels/summary_interface.h"
 #include "tensorflow/core/lib/db/sqlite.h"
 #include "tensorflow/core/platform/protobuf.h"
 
@@ -43,8 +43,9 @@ class CreateSummaryFileWriterOp : public OpKernel {
     OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp));
     const string filename_suffix = tmp->scalar<string>()();
     SummaryWriterInterface* s;
-    OP_REQUIRES_OK(ctx, CreateSummaryWriter(max_queue, flush_millis, logdir,
-                                            filename_suffix, ctx->env(), &s));
+    OP_REQUIRES_OK(ctx,
+                   CreateSummaryFileWriter(max_queue, flush_millis, logdir,
+                                           filename_suffix, ctx->env(), &s));
     OP_REQUIRES_OK(ctx, CreateResource(ctx, HandleFromInput(ctx, 0), s));
   }
 };