Merged commit includes the following changes:
author: A. Unique TensorFlower <gardener@tensorflow.org>
Sat, 3 Mar 2018 02:33:21 +0000 (18:33 -0800)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Sun, 4 Mar 2018 17:58:29 +0000 (09:58 -0800)
187697531  by andrewharp:

    Tweak whitespace for fft2d dep.

--
187696129  by A. Unique TensorFlower:

    Generalize support for logical expressions, comparison operators and multiple comparisons.

--
187692494  by vinuraja:

    * Adds a boolean attribute to ConfigureDistributedTPUOp for internal use.

    * Adds GraphRunner ctor which takes in the device to run the graph on.

--
187692129  by andrewharp:

    Audio utility classes for supporting MFCC and AudioSpectrogram operators

--

PiperOrigin-RevId: 187697531

15 files changed:
tensorflow/contrib/lite/kernels/internal/BUILD
tensorflow/contrib/lite/kernels/internal/mfcc.cc [new file with mode: 0644]
tensorflow/contrib/lite/kernels/internal/mfcc.h [new file with mode: 0644]
tensorflow/contrib/lite/kernels/internal/mfcc_dct.cc [new file with mode: 0644]
tensorflow/contrib/lite/kernels/internal/mfcc_dct.h [new file with mode: 0644]
tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.cc [new file with mode: 0644]
tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h [new file with mode: 0644]
tensorflow/contrib/lite/kernels/internal/spectrogram.cc [new file with mode: 0644]
tensorflow/contrib/lite/kernels/internal/spectrogram.h [new file with mode: 0644]
tensorflow/contrib/py2tf/converters/logical_expressions.py
tensorflow/contrib/py2tf/converters/logical_expressions_test.py
tensorflow/contrib/py2tf/impl/conversion.py
tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc
tensorflow/core/common_runtime/graph_runner.cc
tensorflow/core/common_runtime/graph_runner.h

index 6ccad3b..d5dd2cb 100644 (file)
@@ -309,6 +309,27 @@ cc_library(
     ],
 )
 
+# Audio support classes (MFCC and spectrogram) imported directly from TensorFlow.
+cc_library(
+    name = "audio_utils",
+    srcs = [
+        "mfcc.cc",
+        "mfcc_dct.cc",
+        "mfcc_mel_filterbank.cc",
+        "spectrogram.cc",
+    ],
+    hdrs = [
+        "mfcc.h",
+        "mfcc_dct.h",
+        "mfcc_mel_filterbank.h",
+        "spectrogram.h",
+    ],
+    deps = [
+        "//third_party/fft2d:fft2d_headers",
+        "@fft2d",
+    ],
+)
+
 cc_library(
     name = "tensor_utils",
     srcs = [
diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc.cc b/tensorflow/contrib/lite/kernels/internal/mfcc.cc
new file mode 100644 (file)
index 0000000..eafe0c7
--- /dev/null
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <math.h>
+
+#include "tensorflow/contrib/lite/kernels/internal/mfcc.h"
+
+namespace tflite {
+namespace internal {
+
+const double kDefaultUpperFrequencyLimit = 4000;  // Initial upper_frequency_limit_.
+const double kDefaultLowerFrequencyLimit = 20;  // Initial lower_frequency_limit_.
+const double kFilterbankFloor = 1e-12;  // Smallest value passed to log() in Compute().
+const int kDefaultFilterbankChannelCount = 40;  // Initial filterbank_channel_count_.
+const int kDefaultDCTCoefficientCount = 13;  // Initial dct_coefficient_count_.
+
+Mfcc::Mfcc()  // Starts uninitialized, with every setting at its default.
+    : initialized_(false),
+      lower_frequency_limit_(kDefaultLowerFrequencyLimit),
+      upper_frequency_limit_(kDefaultUpperFrequencyLimit),
+      filterbank_channel_count_(kDefaultFilterbankChannelCount),
+      dct_coefficient_count_(kDefaultDCTCoefficientCount) {}
+
+bool Mfcc::Initialize(int input_length, double input_sample_rate) {
+  bool initialized = mel_filterbank_.Initialize(  // Uses the current settings.
+      input_length, input_sample_rate, filterbank_channel_count_,
+      lower_frequency_limit_, upper_frequency_limit_);
+  initialized &=  // No short-circuit: the DCT is set up even if the above failed.
+      dct_.Initialize(filterbank_channel_count_, dct_coefficient_count_);
+  initialized_ = initialized;  // True only when both stages initialized.
+  return initialized;
+}
+
+void Mfcc::Compute(const std::vector<double>& spectrogram_frame,
+                   std::vector<double>* output) const {
+  if (!initialized_) {
+    // LOG(ERROR) << "Mfcc not initialized.";
+    return;
+  }
+  std::vector<double> working;
+  mel_filterbank_.Compute(spectrogram_frame, &working);
+  for (int i = 0; i < working.size(); ++i) {
+    double val = working[i];
+    if (val < kFilterbankFloor) {
+      val = kFilterbankFloor;
+    }
+    working[i] = log(val);
+  }
+  dct_.Compute(working, output);
+}
+
+}  // namespace internal
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc.h b/tensorflow/contrib/lite/kernels/internal/mfcc.h
new file mode 100644 (file)
index 0000000..d8500ec
--- /dev/null
@@ -0,0 +1,78 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Basic class for computing MFCCs from spectrogram slices.
+
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_H_
+
+#include <vector>
+
+#include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h"
+#include "tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h"
+
+namespace tflite {
+namespace internal {
+
+class Mfcc {
+ public:
+  Mfcc();
+  bool Initialize(int input_length, double input_sample_rate);
+
+  // Input is a single squared-magnitude spectrogram frame. The input spectrum
+  // is converted to linear magnitude and weighted into bands using a
+  // triangular mel filterbank; the log of each band value is taken, followed
+  // by a discrete cosine transform (DCT). Output is populated with the lowest
+  // dct_coefficient_count of these values.
+  void Compute(const std::vector<double>& spectrogram_frame,
+               std::vector<double>* output) const;
+
+  void set_upper_frequency_limit(double upper_frequency_limit) {
+    // CHECK(!initialized_) << "Set frequency limits before calling
+    // Initialize.";
+    upper_frequency_limit_ = upper_frequency_limit;  // Read by Initialize().
+  }
+
+  void set_lower_frequency_limit(double lower_frequency_limit) {
+    // CHECK(!initialized_) << "Set frequency limits before calling
+    // Initialize.";
+    lower_frequency_limit_ = lower_frequency_limit;  // Read by Initialize().
+  }
+
+  void set_filterbank_channel_count(int filterbank_channel_count) {
+    // CHECK(!initialized_) << "Set channel count before calling Initialize.";
+    filterbank_channel_count_ = filterbank_channel_count;  // Read by Initialize().
+  }
+
+  void set_dct_coefficient_count(int dct_coefficient_count) {
+    // CHECK(!initialized_) << "Set coefficient count before calling
+    // Initialize.";
+    dct_coefficient_count_ = dct_coefficient_count;  // Read by Initialize().
+  }
+
+ private:
+  MfccMelFilterbank mel_filterbank_;  // First stage of Compute().
+  MfccDct dct_;  // Last stage of Compute().
+  bool initialized_;  // Result of the most recent Initialize() call.
+  double lower_frequency_limit_;
+  double upper_frequency_limit_;
+  int filterbank_channel_count_;
+  int dct_coefficient_count_;
+};
+
+}  // namespace internal
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc_dct.cc b/tensorflow/contrib/lite/kernels/internal/mfcc_dct.cc
new file mode 100644 (file)
index 0000000..b0b7d18
--- /dev/null
@@ -0,0 +1,78 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h"
+
+#include <math.h>
+
+namespace tflite {
+namespace internal {
+
+MfccDct::MfccDct() : initialized_(false) {}  // Compute() is a no-op until Initialize() succeeds.
+
+bool MfccDct::Initialize(int input_length, int coefficient_count) {
+  coefficient_count_ = coefficient_count;
+  input_length_ = input_length;
+
+  if (coefficient_count_ < 1) {
+    return false;
+  }
+
+  if (input_length < 1) {
+    return false;
+  }
+
+  if (coefficient_count_ > input_length_) {
+    return false;
+  }
+
+  cosines_.resize(coefficient_count_);
+  double fnorm = sqrt(2.0 / input_length_);
+  // Some platforms don't have M_PI, so define a local constant here.
+  const double pi = atan(1) * 4;
+  double arg = pi / input_length_;
+  for (int i = 0; i < coefficient_count_; ++i) {
+    cosines_[i].resize(input_length_);
+    for (int j = 0; j < input_length_; ++j) {
+      cosines_[i][j] = fnorm * cos(i * arg * (j + 0.5));
+    }
+  }
+  initialized_ = true;
+  return true;
+}
+
+void MfccDct::Compute(const std::vector<double> &input,
+                      std::vector<double> *output) const {
+  if (!initialized_) {
+    return;  // *output is left unchanged.
+  }
+
+  output->resize(coefficient_count_);
+  int length = input.size();
+  if (length > input_length_) {
+    length = input_length_;  // Ignore input values beyond the table width.
+  }
+
+  for (int i = 0; i < coefficient_count_; ++i) {
+    double sum = 0.0;  // Dot product of cosine row i with the input.
+    for (int j = 0; j < length; ++j) {
+      sum += cosines_[i][j] * input[j];
+    }
+    (*output)[i] = sum;
+  }
+}
+
+}  // namespace internal
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc_dct.h b/tensorflow/contrib/lite/kernels/internal/mfcc_dct.h
new file mode 100644 (file)
index 0000000..a53f5cb
--- /dev/null
@@ -0,0 +1,43 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Basic minimal DCT class for MFCC speech processing.
+
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_DCT_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_DCT_H_
+
+#include <vector>
+
+namespace tflite {
+namespace internal {
+
+class MfccDct {
+ public:
+  MfccDct();
+  bool Initialize(int input_length, int coefficient_count);  // False on bad sizes.
+  void Compute(const std::vector<double>& input,
+               std::vector<double>* output) const;  // No-op if uninitialized.
+
+ private:
+  bool initialized_;  // Set once Initialize() has succeeded.
+  int coefficient_count_;  // Number of outputs (rows of cosines_).
+  int input_length_;  // Number of inputs per transform (columns of cosines_).
+  std::vector<std::vector<double> > cosines_;  // Precomputed table, [coefficient][input].
+};
+
+}  // namespace internal
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_DCT_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.cc b/tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.cc
new file mode 100644 (file)
index 0000000..c3deb33
--- /dev/null
@@ -0,0 +1,204 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This code resamples the FFT bins, and smooths them with triangle-shaped
+// weights to create a mel-frequency filter bank. For filter i centered at f_i,
+// there is a triangular weighting of the FFT bins that extends from
+// filter f_i-1 (with a value of zero at the left edge of the triangle) to f_i
+// (where the filter value is 1) to f_i+1 (where the filter value returns to
+// zero).
+
+// Note: this code fails if you ask for too many channels.  The algorithm used
+// here assumes that each FFT bin contributes to at most two channels: the
+// right side of a triangle for channel i, and the left side of the triangle
+// for channel i+1.  If you ask for so many channels that some of the
+// resulting mel triangle filters are smaller than a single FFT bin, these
+// channels may end up with no contributing FFT bins.  The resulting mel
+// spectrum output will have some channels that are always zero.
+
+#include "tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h"
+
+#include <math.h>
+
+namespace tflite {
+namespace internal {
+
+MfccMelFilterbank::MfccMelFilterbank() : initialized_(false) {}  // Compute() is a no-op until Initialize() succeeds.
+
+bool MfccMelFilterbank::Initialize(int input_length, double input_sample_rate,
+                                   int output_channel_count,
+                                   double lower_frequency_limit,
+                                   double upper_frequency_limit) {
+  num_channels_ = output_channel_count;
+  sample_rate_ = input_sample_rate;
+  input_length_ = input_length;
+
+  if (num_channels_ < 1) {
+    // LOG(ERROR) << "Number of filterbank channels must be positive.";
+    return false;
+  }
+
+  if (sample_rate_ <= 0) {
+    // LOG(ERROR) << "Sample rate must be positive.";
+    return false;
+  }
+
+  if (input_length < 2) {
+    // LOG(ERROR) << "Input length must greater than 1.";
+    return false;
+  }
+
+  if (lower_frequency_limit < 0) {
+    // LOG(ERROR) << "Lower frequency limit must be nonnegative.";
+    return false;
+  }
+
+  if (upper_frequency_limit <= lower_frequency_limit) {
+    /// LOG(ERROR) << "Upper frequency limit must be greater than "
+    //           << "lower frequency limit.";
+    return false;
+  }
+
+  // An extra center frequency is computed at the top to get the upper
+  // limit on the high side of the final triangular filter.
+  center_frequencies_.resize(num_channels_ + 1);
+  const double mel_low = FreqToMel(lower_frequency_limit);
+  const double mel_hi = FreqToMel(upper_frequency_limit);
+  const double mel_span = mel_hi - mel_low;
+  const double mel_spacing = mel_span / static_cast<double>(num_channels_ + 1);
+  for (int i = 0; i < num_channels_ + 1; ++i) {
+    center_frequencies_[i] = mel_low + (mel_spacing * (i + 1));
+  }
+
+  // Always exclude DC; emulate HTK.
+  const double hz_per_sbin =
+      0.5 * sample_rate_ / static_cast<double>(input_length_ - 1);
+  start_index_ = static_cast<int>(1.5 + (lower_frequency_limit / hz_per_sbin));
+  end_index_ = static_cast<int>(upper_frequency_limit / hz_per_sbin);
+
+  // Maps the input spectrum bin indices to filter bank channels/indices. For
+  // each FFT bin, band_mapper tells us which channel this bin contributes to
+  // on the right side of the triangle.  Thus this bin also contributes to the
+  // left side of the next channel's triangle response.
+  band_mapper_.resize(input_length_);
+  int channel = 0;
+  for (int i = 0; i < input_length_; ++i) {
+    double melf = FreqToMel(i * hz_per_sbin);
+    if ((i < start_index_) || (i > end_index_)) {
+      band_mapper_[i] = -2;  // Indicate an unused Fourier coefficient.
+    } else {
+      while ((center_frequencies_[channel] < melf) &&
+             (channel < num_channels_)) {
+        ++channel;
+      }
+      band_mapper_[i] = channel - 1;  // Can be == -1
+    }
+  }
+
+  // Create the weighting functions to taper the band edges.  The contribution
+  // of any one FFT bin is based on its distance along the continuum between two
+  // mel-channel center frequencies.  This bin contributes weights_[i] to the
+  // current channel and 1-weights_[i] to the next channel.
+  weights_.resize(input_length_);
+  for (int i = 0; i < input_length_; ++i) {
+    channel = band_mapper_[i];
+    if ((i < start_index_) || (i > end_index_)) {
+      weights_[i] = 0.0;
+    } else {
+      if (channel >= 0) {
+        weights_[i] =
+            (center_frequencies_[channel + 1] - FreqToMel(i * hz_per_sbin)) /
+            (center_frequencies_[channel + 1] - center_frequencies_[channel]);
+      } else {
+        weights_[i] = (center_frequencies_[0] - FreqToMel(i * hz_per_sbin)) /
+                      (center_frequencies_[0] - mel_low);
+      }
+    }
+  }
+  // Check the sum of FFT bin weights for every mel band to identify
+  // situations where the mel bands are so narrow that they don't get
+  // significant weight on enough (or any) FFT bins -- i.e., too many
+  // mel bands have been requested for the given FFT size.
+  std::vector<int> bad_channels;
+  for (int c = 0; c < num_channels_; ++c) {
+    float band_weights_sum = 0.0;
+    for (int i = 0; i < input_length_; ++i) {
+      if (band_mapper_[i] == c - 1) {
+        band_weights_sum += (1.0 - weights_[i]);
+      } else if (band_mapper_[i] == c) {
+        band_weights_sum += weights_[i];
+      }
+    }
+    // The lowest mel channels have the fewest FFT bins and the lowest
+    // weights sum.  But given that the target gain at the center frequency
+    // is 1.0, if the total sum of weights is 0.5, we're in bad shape.
+    if (band_weights_sum < 0.5) {
+      bad_channels.push_back(c);
+    }
+  }
+  if (!bad_channels.empty()) {
+    /*
+    LOG(ERROR) << "Missing " << bad_channels.size() << " bands "
+               << " starting at " << bad_channels[0]
+               << " in mel-frequency design. "
+               << "Perhaps too many channels or "
+               << "not enough frequency resolution in spectrum. ("
+               << "input_length: " << input_length
+               << " input_sample_rate: " << input_sample_rate
+               << " output_channel_count: " << output_channel_count
+               << " lower_frequency_limit: " << lower_frequency_limit
+               << " upper_frequency_limit: " << upper_frequency_limit;
+               */
+  }
+  initialized_ = true;
+  return true;
+}
+
+// Compute the mel spectrum from the squared-magnitude FFT input by taking the
+// square root, then summing FFT magnitudes under triangular integration windows
+// whose widths increase with frequency.
+void MfccMelFilterbank::Compute(const std::vector<double> &input,
+                                std::vector<double> *output) const {
+  if (!initialized_) {
+    // LOG(ERROR) << "Mel Filterbank not initialized.";
+    return;  // *output is left unchanged.
+  }
+
+  if (input.size() <= end_index_) {
+    // LOG(ERROR) << "Input too short to compute filterbank";
+    return;  // *output is left unchanged.
+  }
+
+  // Ensure output is right length and reset all values.
+  output->assign(num_channels_, 0.0);
+
+  for (int i = start_index_; i <= end_index_; i++) {  // For each FFT bin
+    double spec_val = sqrt(input[i]);  // Squared magnitude -> linear magnitude.
+    double weighted = spec_val * weights_[i];
+    int channel = band_mapper_[i];  // -1 when the bin is below the first center.
+    if (channel >= 0)
+      (*output)[channel] += weighted;  // Right side of triangle, downward slope
+    channel++;
+    if (channel < num_channels_)
+      (*output)[channel] += spec_val - weighted;  // Left side of triangle
+  }
+}
+
+double MfccMelFilterbank::FreqToMel(double freq) const {
+  return 1127.0 * log(1.0 + (freq / 700.0));  // mel = 1127 * ln(1 + freq / 700).
+}
+
+}  // namespace internal
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h b/tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h
new file mode 100644 (file)
index 0000000..c1db282
--- /dev/null
@@ -0,0 +1,63 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Basic class for applying a mel-scale mapping to a power spectrum.
+
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_MEL_FILTERBANK_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_MEL_FILTERBANK_H_
+
+#include <vector>
+
+namespace tflite {
+namespace internal {
+
+class MfccMelFilterbank {
+ public:
+  MfccMelFilterbank();
+  bool Initialize(int input_length,  // Number of unique FFT bins fftsize/2+1.
+                  double input_sample_rate, int output_channel_count,
+                  double lower_frequency_limit, double upper_frequency_limit);
+
+  // Takes a squared-magnitude spectrogram slice as input, computes a
+  // triangular-mel-weighted linear-magnitude filterbank, and places the result
+  // in output.
+  void Compute(const std::vector<double>& input,
+               std::vector<double>* output) const;
+
+ private:
+  double FreqToMel(double freq) const;
+  bool initialized_;  // Set once Initialize() has succeeded.
+  int num_channels_;  // Number of mel channels written by Compute().
+  double sample_rate_;
+  int input_length_;  // Size of weights_ and band_mapper_.
+  std::vector<double> center_frequencies_;  // In mel, for each mel channel.
+
+  // Each FFT bin b contributes to two triangular mel channels, with
+  // proportion weights_[b] going into mel channel band_mapper_[b], and
+  // proportion (1 - weights_[b]) going into channel band_mapper_[b] + 1.
+  // Thus, weights_ contains the weighting applied to each FFT bin for the
+  // upper-half of the triangular band.
+  std::vector<double> weights_;  // Right-side weight for this FFT bin.
+
+  // FFT bin i contributes to the upper side of mel channel band_mapper_[i]
+  std::vector<int> band_mapper_;
+  int start_index_;  // Lowest FFT bin used to calculate mel spectrum.
+  int end_index_;    // Highest FFT bin used to calculate mel spectrum.
+};
+
+}  // namespace internal
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_MEL_FILTERBANK_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/spectrogram.cc b/tensorflow/contrib/lite/kernels/internal/spectrogram.cc
new file mode 100644 (file)
index 0000000..66ca694
--- /dev/null
@@ -0,0 +1,244 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/kernels/internal/spectrogram.h"
+
+#include <math.h>
+
+#include "third_party/fft2d/fft.h"
+
+namespace tflite {
+namespace internal {
+
+using std::complex;
+
+namespace {
+// Fills *window with window_length samples of 0.5 - 0.5 * cos(2 * pi * i / N).
+void GetPeriodicHann(int window_length, std::vector<double>* window) {
+  // Some platforms don't have M_PI, so define a local constant here.
+  const double pi = std::atan(1) * 4;  // NOTE(review): std::atan needs <cmath>; only <math.h> is included directly - confirm.
+  window->resize(window_length);
+  for (int i = 0; i < window_length; ++i) {
+    (*window)[i] = 0.5 - 0.5 * cos((2 * pi * i) / window_length);
+  }
+}
+}  // namespace
+
+bool Spectrogram::Initialize(int window_length, int step_length) {
+  std::vector<double> window;  // Default window, built by GetPeriodicHann().
+  GetPeriodicHann(window_length, &window);
+  return Initialize(window, step_length);  // Validation happens in this overload.
+}
+
+inline int Log2Floor(uint n) {  // NOTE(review): `uint` is not a standard C++ type.
+  if (n == 0) return -1;  // No bit is set, so there is no valid result.
+  int log = 0;
+  uint value = n;
+  for (int i = 4; i >= 0; --i) {  // Tries shifts of 16, 8, 4, 2, then 1.
+    int shift = (1 << i);
+    uint x = value >> shift;
+    if (x != 0) {  // A set bit remains at or above `shift`.
+      value = x;
+      log += shift;
+    }
+  }
+  assert(value == 1);  // NOTE(review): no <cassert> include is visible in this file.
+  return log;
+}
+
+inline int Log2Ceiling(uint n) {
+  int floor = Log2Floor(n);  // -1 when n == 0.
+  if (n == (n & ~(n - 1)))  // zero or a power of two
+    return floor;
+  else
+    return floor + 1;  // Round up for every other value.
+}
+
+inline uint NextPowerOfTwo(uint value) {  // For value >= 1: smallest power of two >= value.
+  int exponent = Log2Ceiling(value);  // -1 when value == 0.
+  // DCHECK_LT(exponent, std::numeric_limits<uint32>::digits);
+  return 1 << exponent;  // NOTE(review): shifts a signed 1; exponent -1 (value == 0) is undefined.
+}
+
+bool Spectrogram::Initialize(const std::vector<double>& window,
+                             int step_length) {
+  window_length_ = window.size();
+  window_ = window;  // Copy window.
+  if (window_length_ < 2) {
+    // LOG(ERROR) << "Window length too short.";
+    initialized_ = false;
+    return false;
+  }
+
+  step_length_ = step_length;
+  if (step_length_ < 1) {
+    // LOG(ERROR) << "Step length must be positive.";
+    initialized_ = false;
+    return false;
+  }
+
+  fft_length_ = NextPowerOfTwo(window_length_);  // Frames are zero-padded to this.
+  // CHECK(fft_length_ >= window_length_);
+  output_frequency_channels_ = 1 + fft_length_ / 2;
+
+  // Allocate 2 more than what rdft needs, so we can rationalize the layout.
+  fft_input_output_.assign(fft_length_ + 2, 0.0);
+
+  int half_fft_length = fft_length_ / 2;
+  fft_double_working_area_.assign(half_fft_length, 0.0);
+  fft_integer_working_area_.assign(2 + static_cast<int>(sqrt(half_fft_length)),
+                                   0);
+  // Set flag element to ensure that the working areas are initialized
+  // on the first call to rdft.  It's redundant given the assign above,
+  // but keep it as a reminder.
+  fft_integer_working_area_[0] = 0;
+  input_queue_.clear();  // Drop any samples buffered before this call.
+  samples_to_next_step_ = window_length_;  // The first frame needs a full window.
+  initialized_ = true;
+  return true;
+}
+
+template <class InputSample, class OutputSample>
+bool Spectrogram::ComputeComplexSpectrogram(
+    const std::vector<InputSample>& input,
+    std::vector<std::vector<complex<OutputSample>>>* output) {
+  if (!initialized_) {
+    // LOG(ERROR) << "ComputeComplexSpectrogram() called before successful call
+    // "
+    //           << "to Initialize().";
+    return false;
+  }
+  // CHECK(output);
+  output->clear();  // Only frames completed during this call are returned.
+  int input_start = 0;  // Index of the next unread sample in input.
+  while (GetNextWindowOfSamples(input, &input_start)) {
+    // DCHECK_EQ(input_queue_.size(), window_length_);
+    ProcessCoreFFT();  // Processes input_queue_ to fft_input_output_.
+    // Add a new slice vector onto the output, to save new result to.
+    output->resize(output->size() + 1);
+    // Get a reference to the newly added slice to fill in.
+    auto& spectrogram_slice = output->back();
+    spectrogram_slice.resize(output_frequency_channels_);
+    for (int i = 0; i < output_frequency_channels_; ++i) {
+      // This will convert double to float if it needs to.
+      spectrogram_slice[i] = complex<OutputSample>(
+          fft_input_output_[2 * i], fft_input_output_[2 * i + 1]);
+    }
+  }
+  return true;
+}
+// Explicit instantiations for each float/double input/output combination:
+template bool Spectrogram::ComputeComplexSpectrogram(
+    const std::vector<float>& input, std::vector<std::vector<complex<float>>>*);
+template bool Spectrogram::ComputeComplexSpectrogram(
+    const std::vector<double>& input,
+    std::vector<std::vector<complex<float>>>*);
+template bool Spectrogram::ComputeComplexSpectrogram(
+    const std::vector<float>& input,
+    std::vector<std::vector<complex<double>>>*);
+template bool Spectrogram::ComputeComplexSpectrogram(
+    const std::vector<double>& input,
+    std::vector<std::vector<complex<double>>>*);
+
+template <class InputSample, class OutputSample>
+bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
+    const std::vector<InputSample>& input,
+    std::vector<std::vector<OutputSample>>* output) {
+  if (!initialized_) {
+    // LOG(ERROR) << "ComputeSquaredMagnitudeSpectrogram() called before "
+    //           << "successful call to Initialize().";
+    return false;
+  }
+  // CHECK(output);
+  output->clear();  // Only frames completed during this call are returned.
+  int input_start = 0;  // Index of the next unread sample in input.
+  while (GetNextWindowOfSamples(input, &input_start)) {
+    // DCHECK_EQ(input_queue_.size(), window_length_);
+    ProcessCoreFFT();  // Processes input_queue_ to fft_input_output_.
+    // Add a new slice vector onto the output, to save new result to.
+    output->resize(output->size() + 1);
+    // Get a reference to the newly added slice to fill in.
+    auto& spectrogram_slice = output->back();
+    spectrogram_slice.resize(output_frequency_channels_);
+    for (int i = 0; i < output_frequency_channels_; ++i) {
+      // Similar to the Complex case, except storing the norm.
+      // But the norm function is known to be a performance killer,
+      // so do it this way with explicit real and imaginary temps.
+      const double re = fft_input_output_[2 * i];
+      const double im = fft_input_output_[2 * i + 1];
+      // Which finally converts double to float if it needs to.
+      spectrogram_slice[i] = re * re + im * im;
+    }
+  }
+  return true;
+}
+// Explicit instantiations for each float/double input/output combination:
+template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
+    const std::vector<float>& input, std::vector<std::vector<float>>*);
+template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
+    const std::vector<double>& input, std::vector<std::vector<float>>*);
+template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
+    const std::vector<float>& input, std::vector<std::vector<double>>*);
+template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
+    const std::vector<double>& input, std::vector<std::vector<double>>*);
+
+// Return true if a full window of samples is prepared; manage the queue.
+template <class InputSample>
+bool Spectrogram::GetNextWindowOfSamples(const std::vector<InputSample>& input,
+                                         int* input_start) {
+  auto input_it = input.begin() + *input_start;  // First unread sample.
+  int input_remaining = input.end() - input_it;  // Unread samples left in input.
+  if (samples_to_next_step_ > input_remaining) {
+    // Copy in as many samples are left and return false, no full window.
+    input_queue_.insert(input_queue_.end(), input_it, input.end());
+    *input_start += input_remaining;  // Increases it to input.size().
+    samples_to_next_step_ -= input_remaining;  // Still owed by a later call.
+    return false;  // Not enough for a full window.
+  } else {
+    // Copy just enough into queue to make a new window, then trim the
+    // front off the queue to make it window-sized.
+    input_queue_.insert(input_queue_.end(), input_it,
+                        input_it + samples_to_next_step_);
+    *input_start += samples_to_next_step_;
+    input_queue_.erase(
+        input_queue_.begin(),
+        input_queue_.begin() + input_queue_.size() - window_length_);
+    // DCHECK_EQ(window_length_, input_queue_.size());
+    samples_to_next_step_ = step_length_;  // Be ready for next time.
+    return true;  // Yes, input_queue_ now contains exactly a window-full.
+  }
+}
+
+void Spectrogram::ProcessCoreFFT() {
+  for (int j = 0; j < window_length_; ++j) {  // Apply the window to the queue.
+    fft_input_output_[j] = input_queue_[j] * window_[j];
+  }
+  // Zero-pad the rest of the input buffer.
+  for (int j = window_length_; j < fft_length_; ++j) {
+    fft_input_output_[j] = 0.0;
+  }
+  const int kForwardFFT = 1;  // 1 means forward; -1 reverse.
+  // This real FFT is a fair amount faster than using cdft here.
+  rdft(fft_length_, kForwardFFT, &fft_input_output_[0],
+       &fft_integer_working_area_[0], &fft_double_working_area_[0]);
+  // Make rdft result look like cdft result;
+  // unpack the last real value from the first position's imag slot.
+  fft_input_output_[fft_length_] = fft_input_output_[1];
+  fft_input_output_[fft_length_ + 1] = 0;  // Imaginary part of the last bin.
+  fft_input_output_[1] = 0;  // Imaginary part of the first bin.
+}
+
+}  // namespace internal
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/spectrogram.h b/tensorflow/contrib/lite/kernels/internal/spectrogram.h
new file mode 100644 (file)
index 0000000..b77a68f
--- /dev/null
@@ -0,0 +1,110 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Class for generating spectrogram slices from a waveform.
+// Initialize() should be called before calls to other functions.  Once
+// Initialize() has been called and returned true, the Compute*() functions can
+// be called repeatedly with sequential input data (i.e. the first element of
+// the next input vector directly follows the last element of the previous
+// input vector). Whenever enough audio samples are buffered to produce a
+// new frame, it will be placed in output. Output is cleared on each
+// call to Compute*(). This class is thread-unsafe, and should only be
+// called from one thread at a time.
+// With the default parameters, the output of this class should be very
+// close to the results of the following MATLAB code:
+// overlap_samples = window_length_samples - step_samples;
+// window = hann(window_length_samples, 'periodic');
+// S = abs(spectrogram(audio, window, overlap_samples)).^2;
+
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_SPECTROGRAM_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_SPECTROGRAM_H_
+
+#include <complex>
+#include <deque>
+#include <vector>
+
+#include "third_party/fft2d/fft.h"
+
+namespace tflite {
+namespace internal {
+
+class Spectrogram {
+ public:
+  Spectrogram() : initialized_(false) {}
+  ~Spectrogram() {}
+
+  // Initializes the class with a given window length and step length
+  // (both in samples). Internally a Hann window is used as the window
+  // function. Returns true on success, after which the Compute*()
+  // functions may be called. window_length must be greater than 1 and
+  // step_length must be greater than 0.
+  bool Initialize(int window_length, int step_length);
+
+  // Initializes with an explicit window function instead of a window length.
+  bool Initialize(const std::vector<double>& window, int step_length);
+
+  // Processes an arbitrary amount of audio data (contained in input)
+  // to yield complex spectrogram frames. After a successful call to
+  // Initialize(), this function may be called repeatedly with new input
+  // data each time.  The audio input is buffered internally, and the output
+  // vector is populated with as many temporally-ordered spectral slices
+  // as it is possible to generate from the input.  The output is cleared
+  // on each call before the new frames (if any) are added.
+  //
+  // The template parameters can be float or double.
+  template <class InputSample, class OutputSample>
+  bool ComputeComplexSpectrogram(
+      const std::vector<InputSample>& input,
+      std::vector<std::vector<std::complex<OutputSample>>>* output);
+
+  // This function works like the one above, but returns the power
+  // (the squared magnitude) of each complex value.
+  template <class InputSample, class OutputSample>
+  bool ComputeSquaredMagnitudeSpectrogram(
+      const std::vector<InputSample>& input,
+      std::vector<std::vector<OutputSample>>* output);
+
+  // Returns a reference to the window function used internally.
+  const std::vector<double>& GetWindow() const { return window_; }
+
+  // Returns the number of frequency channels in the spectrogram.
+  int output_frequency_channels() const { return output_frequency_channels_; }
+
+ private:
+  template <class InputSample>
+  bool GetNextWindowOfSamples(const std::vector<InputSample>& input,
+                              int* input_start);
+  void ProcessCoreFFT();
+
+  int fft_length_;  // FFT size; the window is zero-padded up to this length.
+  int output_frequency_channels_;
+  int window_length_;
+  int step_length_;
+  bool initialized_;
+  int samples_to_next_step_;  // Samples still needed to complete a window.
+
+  std::vector<double> window_;
+  std::vector<double> fft_input_output_;  // FFT buffer, transformed in place.
+  std::deque<double> input_queue_;  // Buffered input samples.
+
+  // Working data areas passed to the FFT routines.
+  std::vector<int> fft_integer_working_area_;
+  std::vector<double> fft_double_working_area_;
+};
+
+}  // namespace internal
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_SPECTROGRAM_H_
index df980d4..766aa11 100644 (file)
@@ -23,52 +23,107 @@ from __future__ import print_function
 
 import gast
 
-from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.contrib.py2tf.pyct import transformer
 
 
-class LogicalExpressionTransformer(gast.NodeTransformer):
+# TODO(mdan): Properly extract boolean ops according to lazy eval rules.
+# Note that this isn't completely safe either, because tensors may have control
+# dependencies.
+# Note that for loops, this should be done after the loop was converted to
+# tf.while_loop so that the expanded conditionals are properly scoped.
+
+# Used to signal that an operand is safe for non-lazy evaluation.
+SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'
+
+
+class LogicalExpressionTransformer(transformer.Base):
   """Converts logical expressions to corresponding TF calls."""
 
-  def __init__(self):
+  def __init__(self, context):
+    super(LogicalExpressionTransformer, self).__init__(context)
     # TODO(mdan): Look into replacing with bitwise operators instead.
     self.op_mapping = {
-        gast.And: 'tf.logical_and',
-        gast.Or: 'tf.logical_or',
-        gast.Not: 'tf.logical_not',
-        gast.Eq: 'tf.equal',
+        gast.And: 'logical_and',
+        gast.Eq: 'equal',
+        gast.Gt: 'greater',
+        gast.GtE: 'greater_equal',
+        gast.Lt: 'less',
+        gast.LtE: 'less_equal',
+        gast.Not: 'logical_not',
+        gast.NotEq: 'not_equal',
+        gast.Or: 'logical_or',
+        gast.USub: 'negative',
     }
 
+  def _expect_simple_symbol(self, operand):
+    if isinstance(operand, gast.Name):
+      return
+    if anno.hasanno(operand, SAFE_BOOLEAN_OPERAND):
+      return
+    raise NotImplementedError(
+        'only simple local variables are supported in logical and compound '
+        'comparison expressions; for example, we support "a or b" but not '
+        '"a.x or b"; for a workaround, assign the expression to a local '
+        'variable and use that instead, for example "tmp = a.x", "tmp or b"')
+
+  def _matching_tf_op(self, operator):
+    op_type = type(operator)
+    mapped_op = self.op_mapping.get(op_type)
+    if not mapped_op:
+      raise NotImplementedError('operator %s is not yet supported' % op_type)
+    return mapped_op
+
+  def _inline_tf_op(self, op_name, args):
+    template = """
+      tf.op_name(args)
+    """
+    replacement = templates.replace(template, op_name=op_name, args=args)
+    # It's a body with a single expression, we want its value.
+    n = replacement[0].value
+    anno.setanno(n, SAFE_BOOLEAN_OPERAND, True)
+    return n
+
   def visit_Compare(self, node):
     node = self.generic_visit(node)
-    if len(node.ops) > 1:
-      raise NotImplementedError()
-    cmp_type = type(node.ops[0])
-    if cmp_type in self.op_mapping:
-      tf_function = parser.parse_str(self.op_mapping[cmp_type]).body[0].value
-      return gast.Call(
-          func=tf_function, args=[node.left, node.comparators[0]], keywords=[])
-    return node
+    ops_and_comps = list(zip(node.ops, node.comparators))
+    left = node.left
+    op_tree = None
+
+    # Repeated comparisons are converted to conjunctions:
+    #   a < b < c   ->   a < b and b < c
+    while ops_and_comps:
+      op, right = ops_and_comps.pop(0)
+      binary_comparison = self._inline_tf_op(self._matching_tf_op(op),
+                                             (left, right))
+      if isinstance(left, gast.Name) and isinstance(right, gast.Name):
+        anno.setanno(binary_comparison, SAFE_BOOLEAN_OPERAND, True)
+      if op_tree:
+        self._expect_simple_symbol(right)
+        op_tree = self._inline_tf_op('logical_and',
+                                     (binary_comparison, op_tree))
+      else:
+        op_tree = binary_comparison
+      left = right
+    assert op_tree is not None
+    return op_tree
 
   def visit_UnaryOp(self, node):
     node = self.generic_visit(node)
-    if isinstance(node.op, gast.Not):
-      tf_function = parser.parse_str(self.op_mapping[type(
-          node.op)]).body[0].value
-      node = gast.Call(func=tf_function, args=[node.operand], keywords=[])
-    return node
+    return self._inline_tf_op(self._matching_tf_op(node.op), node.operand)
 
   def visit_BoolOp(self, node):
-    # TODO(mdan): A normalizer may be useful here. Use ANF?
     node = self.generic_visit(node)
-    tf_function = parser.parse_str(self.op_mapping[type(node.op)]).body[0].value
-    left = node.values[0]
-    for i in range(1, len(node.values)):
-      left = gast.Call(
-          func=tf_function, args=[left, node.values[i]], keywords=[])
-    return left
-
-
-def transform(node):
-  transformer = LogicalExpressionTransformer()
-  node = transformer.visit(node)
-  return node
+    node_values = node.values
+    right = node.values.pop()
+    self._expect_simple_symbol(right)
+    while node_values:
+      left = node_values.pop()
+      self._expect_simple_symbol(left)
+      right = self._inline_tf_op(self._matching_tf_op(node.op), (left, right))
+    return right
+
+
+def transform(node, context):
+  return LogicalExpressionTransformer(context).visit(node)
index a28326c..eb28c30 100644 (file)
@@ -32,7 +32,7 @@ class GradientsFunctionTest(converter_test_base.TestCase):
       return a == b
 
     node = self.parse_and_analyze(test_fn, {})
-    node = logical_expressions.transform(node)
+    node = logical_expressions.transform(node, self.ctx)
 
     with self.compiled(node, math_ops.equal) as result:
       with self.test_session() as sess:
@@ -45,7 +45,7 @@ class GradientsFunctionTest(converter_test_base.TestCase):
       return (a or b) and (a or b or c)
 
     node = self.parse_and_analyze(test_fn, {})
-    node = logical_expressions.transform(node)
+    node = logical_expressions.transform(node, self.ctx)
 
     with self.compiled(node, math_ops.logical_or,
                        math_ops.logical_and) as result:
index d95469e..c6f4988 100644 (file)
@@ -312,7 +312,7 @@ def node_to_graph(node, ctx, nocompile_decorators):
 
   # control_flow may create new symbols and change scopes.
   node = _static_analysis_pass(node, ctx)
-  node = logical_expressions.transform(node)
+  node = logical_expressions.transform(node, ctx)
   node = side_effect_guards.transform(node, ctx)
   node = name_scopes.transform(node, ctx)
 
index f8de8ba..7bf5c21 100644 (file)
@@ -191,6 +191,7 @@ REGISTER_OP("ConfigureDistributedTPU")
     .Output("topology: string")
     .Attr("embedding_config: string = ''")
     .Attr("tpu_embedding_config: string = ''")
+    .Attr("is_global_init: bool = false")
     .SetIsStateful()
     .SetShapeFn(shape_inference::UnknownShape)
     .Doc(R"doc(
@@ -202,6 +203,7 @@ topology.
 tpu_embedding_config: Serialized tensorflow.tpu.TPUEmbeddingConfiguration that
 describes the embedding lookups of the program.
 embedding_config: Reserved. Do not use.
+is_global_init: Reserved. Do not use.
 )doc");
 
 REGISTER_OP("ShutdownDistributedTPU")
index f1082a6..1125d2a 100644 (file)
@@ -97,7 +97,9 @@ class SimpleRendezvous : public Rendezvous {
 
 }  // namespace
 
-GraphRunner::GraphRunner(Env* env) : cpu_device_(GetCPUDevice(env)) {}
+GraphRunner::GraphRunner(Env* env)
+    : device_deleter_(GetCPUDevice(env)), device_(device_deleter_.get()) {}
+GraphRunner::GraphRunner(Device* device) : device_(device) {}
 
 GraphRunner::~GraphRunner() {}
 
@@ -105,17 +107,18 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
                         const NamedTensorList& inputs,
                         const std::vector<string>& output_names,
                         std::vector<Tensor>* outputs) {
-  if (cpu_device_ == nullptr) {
+  if (device_ == nullptr) {
     return errors::NotFound("Cannot find a device for GraphRunner.");
   }
 
   if (function_library && function_library->device() &&
-      function_library->device()->device_type() != cpu_device_->device_type()) {
-    // We are running on a CPU but the function library is for a non-CPU device,
-    // so just ignore the function_library.
+      function_library->device()->device_type() != device_->device_type()) {
+    // Mismatch between function_library's device_type and device_'s
+    // device_type.
     // TODO(matthewmurray) Can we create a new FunctionLibraryRuntime that is
-    // identical to function_library except that it uses CPU?
-    VLOG(1) << "Cannot run on CPU device with a function library for a "
+    // identical to function_library except that it uses the given 'device_'?
+    VLOG(1) << "Cannot run on: " << device_->device_type()
+            << " with a function library for a "
             << function_library->device()->device_type() << " device.";
     function_library = nullptr;
   }
@@ -146,8 +149,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
   subgraph::RewriteGraphMetadata metadata;
   TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
       graph_to_run.get(), input_names, output_names, {} /* target nodes */,
-      cpu_device_->attributes(), false /* use_function_convention */,
-      &metadata));
+      device_->attributes(), false /* use_function_convention */, &metadata));
 
   // Create the local executor and the Rendezvous for fetching back the
   // constants.
@@ -158,13 +160,12 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
 
   LocalExecutorParams params;
   // The ownership of the output tensors are bound to this device's lifetime.
-  params.device = cpu_device_.get();
+  params.device = device_;
   params.function_library = function_library;
   const int producer = graph_to_run->versions().producer();
   params.create_kernel = [this, producer](const NodeDef& ndef,
                                           OpKernel** kernel) {
-    return CreateNonCachedKernel(cpu_device_.get(), nullptr, ndef, producer,
-                                 kernel);
+    return CreateNonCachedKernel(device_, nullptr, ndef, producer, kernel);
   };
   params.delete_kernel = [](OpKernel* kernel) { delete kernel; };
 
index 1e4ae77..1c4b2b7 100644 (file)
@@ -36,12 +36,14 @@ namespace tensorflow {
 // This class is only meant for internal use where one needs to
 // partially evaluate inexpensive nodes in a graph, such as for shape
 // inference or for constant folding.  Because of its limited, simple
-// use-cases, it executes all computation on the CPU and is not meant
-// to be particularly lightweight, fast, or efficient.
+// use-cases, it executes all computation on the given device (CPU by default)
+// and is not meant to be particularly lightweight, fast, or efficient.
 class GraphRunner {
  public:
   // REQUIRES: `env` is not nullptr.
   GraphRunner(Env* env);
+  // REQUIRES: 'device' is not nullptr. Not owned.
+  GraphRunner(Device* device);
   ~GraphRunner();
 
   // Function semantics for `inputs`, `output_names` and `outputs`
@@ -59,7 +61,8 @@ class GraphRunner {
              std::vector<Tensor>* outputs);
 
  private:
-  std::unique_ptr<Device> cpu_device_;
+  std::unique_ptr<Device> device_deleter_;
+  Device* const device_;
 };
 
 }  // namespace tensorflow