Fix tf.nn.fractional_max_pool output having the same batch size when fed with different...
author A. Unique TensorFlower <gardener@tensorflow.org>
Wed, 3 Jan 2018 23:10:58 +0000 (15:10 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 3 Jan 2018 23:15:02 +0000 (15:15 -0800)
PiperOrigin-RevId: 180724096

tensorflow/core/kernels/BUILD
tensorflow/core/kernels/fractional_avg_pool_op.cc
tensorflow/core/kernels/fractional_max_pool_op.cc
tensorflow/python/kernel_tests/BUILD
tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
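
Before this change, both kernels cached the row/column pooling sequences generated on the first Compute call whenever deterministic=True (pooling_region_generated_), so a later run through the same op with a differently shaped input reused sequences sized for the first input. The diffs below drop that cached state and instead re-seed a GuardedPhiloxRandom from the seed/seed2 attributes on every call. A minimal sketch of the now-working usage, assuming the TF 1.x Python API of this era:

    import numpy as np
    import tensorflow as tf  # assumes TensorFlow 1.x

    # One op whose batch and spatial dimensions are unknown until run time.
    value = tf.placeholder(tf.float32, [None, None, None, 3])
    pool, row_seq, col_seq = tf.nn.fractional_max_pool(
        value, [1, 1.5, 1.5, 1], deterministic=True, seed=123, seed2=456)

    with tf.Session() as sess:
        # Previously the second run reused pooling sequences generated for
        # the first input's shape; now each run sizes its output correctly.
        out_a = sess.run(pool, {value: np.zeros([3, 32, 32, 3])})
        out_b = sess.run(pool, {value: np.zeros([4, 60, 60, 3])})
        print(out_a.shape)  # (3, 21, 21, 3)
        print(out_b.shape)  # (4, 40, 40, 3)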

index ae39c4522dbf304b5156478f2b5571180cba567d..a1b62179f586fecd146ef350256a7acfc9f3427f 100644 (file)
@@ -3370,6 +3370,7 @@ tf_kernel_library(
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
         "//tensorflow/core:nn_ops_op_lib",
         "//third_party/eigen3",
     ],
index bfdb7b4a1e4cc9af9745896c5ff1341f00efdffe..47f4189c30f10644ca7b040677ebadf439a9dc75 100644 (file)
@@ -24,6 +24,7 @@ limitations under the License.
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/util/guarded_philox_random.h"
@@ -47,9 +48,20 @@ class FractionalAvgPoolOp : public OpKernel {
         errors::Unimplemented("Fractional average pooling is not yet "
                               "supported on the batch nor channel dimension."));
     OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_));
-    pooling_region_generated_ = false;
-    // Initialize philox random generator.
-    OP_REQUIRES_OK(context, generator_.Init(context));
+    OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
+    OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
+    if (deterministic_) {
+      // When deterministic_ is true and neither seed is set, pick random seeds.
+      if ((seed_ == 0) && (seed2_ == 0)) {
+        seed_ = random::New64();
+        seed2_ = random::New64();
+      }
+    } else {
+      OP_REQUIRES(
+          context, (seed_ == 0) && (seed2_ == 0),
+          errors::InvalidArgument(
+              "Both seed and seed2 should be 0 if deterministic is false."));
+    }
   }
 
   void Compute(OpKernelContext* context) override {
@@ -64,47 +76,35 @@ class FractionalAvgPoolOp : public OpKernel {
     OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims,
                 errors::InvalidArgument("tensor_in must be 4-dimensional"));
 
+    std::vector<int> input_size(tensor_in_and_out_dims);
+    std::vector<int> output_size(tensor_in_and_out_dims);
     for (int i = 0; i < tensor_in_and_out_dims; ++i) {
-      input_size_.push_back(tensor_in.dim_size(i));
+      input_size[i] = tensor_in.dim_size(i);
     }
     // Output size.
     for (int i = 0; i < tensor_in_and_out_dims; ++i) {
-      output_size_.push_back(
-          static_cast<int>(floor(input_size_[i] / pooling_ratio_[i])));
-      DCHECK_GT(output_size_[i], 0);
+      output_size[i] =
+          static_cast<int>(floor(input_size[i] / pooling_ratio_[i]));
+      DCHECK_GT(output_size[i], 0);
     }
 
     // Generate pooling sequence.
     std::vector<int64> row_cum_seq;
     std::vector<int64> col_cum_seq;
-    if (deterministic_) {
-      if (pooling_region_generated_) {
-        row_cum_seq = row_cum_seq_;
-        col_cum_seq = col_cum_seq_;
-      } else {
-        row_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1],
-                                              &generator_, pseudo_random_);
-        col_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
-                                              &generator_, pseudo_random_);
-        mutex_lock lock(mu_);
-        row_cum_seq_ = row_cum_seq;
-        col_cum_seq_ = col_cum_seq;
-        pooling_region_generated_ = true;
-      }
-    } else {
-      row_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1],
-                                            &generator_, pseudo_random_);
-      col_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
-                                            &generator_, pseudo_random_);
-    }
+    GuardedPhiloxRandom generator;
+    generator.Init(seed_, seed2_);
+    row_cum_seq = GeneratePoolingSequence(input_size[1], output_size[1],
+                                          &generator, pseudo_random_);
+    col_cum_seq = GeneratePoolingSequence(input_size[2], output_size[2],
+                                          &generator, pseudo_random_);
 
     // Prepare output.
     Tensor* output_tensor = nullptr;
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(
-                       0, TensorShape({output_size_[0], output_size_[1],
-                                       output_size_[2], output_size_[3]}),
-                       &output_tensor));
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                0,
+                                TensorShape({output_size[0], output_size[1],
+                                             output_size[2], output_size[3]}),
+                                &output_tensor));
     Tensor* output_row_seq_tensor = nullptr;
     OP_REQUIRES_OK(context,
                    context->allocate_output(
@@ -116,12 +116,11 @@ class FractionalAvgPoolOp : public OpKernel {
                        2, TensorShape({static_cast<int64>(col_cum_seq.size())}),
                        &output_col_seq_tensor));
 
-    ConstEigenMatrixMap in_mat(
-        tensor_in.flat<T>().data(), input_size_[3],
-        input_size_[2] * input_size_[1] * input_size_[0]);
+    ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), input_size[3],
+                               input_size[2] * input_size[1] * input_size[0]);
 
-    EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size_[3],
-                           output_size_[2] * output_size_[1] * output_size_[0]);
+    EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size[3],
+                           output_size[2] * output_size[1] * output_size[0]);
     // out_count corresponds to number of elements in each pooling cell.
     Eigen::Matrix<T, Eigen::Dynamic, 1> out_count(out_mat.cols());
 
@@ -146,9 +145,9 @@ class FractionalAvgPoolOp : public OpKernel {
     // 1: row / row
     // 2: col / col
     // 3: depth / channel
-    const int64 row_max = input_size_[1] - 1;
-    const int64 col_max = input_size_[2] - 1;
-    for (int64 b = 0; b < input_size_[0]; ++b) {
+    const int64 row_max = input_size[1] - 1;
+    const int64 col_max = input_size[2] - 1;
+    for (int64 b = 0; b < input_size[0]; ++b) {
       // row sequence.
       for (int64 hs = 0; hs < row_cum_seq.size() - 1; ++hs) {
         // row start and end.
@@ -160,7 +159,7 @@ class FractionalAvgPoolOp : public OpKernel {
         // col sequence.
         for (int64 ws = 0; ws < col_cum_seq.size() - 1; ++ws) {
           const int64 out_offset =
-              (b * output_size_[1] + hs) * output_size_[2] + ws;
+              (b * output_size[1] + hs) * output_size[2] + ws;
           // col start and end.
           const int64 col_start = col_cum_seq[ws];
           int64 col_end =
@@ -169,7 +168,7 @@ class FractionalAvgPoolOp : public OpKernel {
           for (int64 h = row_start; h <= row_end; ++h) {
             for (int64 w = col_start; w <= col_end; ++w) {
               const int64 in_offset =
-                  (b * input_size_[1] + h) * input_size_[2] + w;
+                  (b * input_size[1] + h) * input_size[2] + w;
               out_mat.col(out_offset) += in_mat.col(in_offset);
               out_count(out_offset)++;
             }
@@ -183,18 +182,11 @@ class FractionalAvgPoolOp : public OpKernel {
 
  private:
   bool deterministic_;
-  // meaningful only when deterministic_ is true.
-  mutex mu_;
-  std::vector<int64> row_cum_seq_;
-  std::vector<int64> col_cum_seq_;
-  bool pooling_region_generated_;
-
-  std::vector<int32> input_size_;
-  std::vector<int32> output_size_;
+  int64 seed_;
+  int64 seed2_;
   std::vector<float> pooling_ratio_;
   bool pseudo_random_;
   bool overlapping_;
-  GuardedPhiloxRandom generator_;
 };
 
 #define REGISTER_FRACTIONALAVGPOOL(type)                                      \
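
The constructor above also pins down the seed contract: with deterministic=True and seed == seed2 == 0, random seeds are drawn once via random::New64 (self-consistent within an op instance, though not reproducible across instances); with deterministic=False, nonzero seeds are rejected. A hedged sketch of the user-visible effect (TF 1.x API assumed); the error surfaces when the kernel is instantiated, i.e. at run time:

    import tensorflow as tf  # assumes TensorFlow 1.x

    x = tf.zeros([1, 9, 9, 1])
    bad, _, _ = tf.nn.fractional_avg_pool(
        x, [1, 1.5, 1.5, 1], deterministic=False, seed=1, seed2=0)

    with tf.Session() as sess:
        try:
            sess.run(bad)
        except tf.errors.InvalidArgumentError as e:
            # "Both seed and seed2 should be 0 if deterministic is false."
            print(e.message)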
index 33d73c84776341cf08243d828ee372456554e2cf..cf580adab256bf055f206f44a5996c1e5487540a 100644 (file)
@@ -24,6 +24,7 @@ limitations under the License.
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/util/guarded_philox_random.h"
@@ -50,9 +51,20 @@ class FractionalMaxPoolOp : public OpKernel {
                               "supported on the batch nor channel dimension."));
 
     OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_));
-    pooling_region_generated_ = false;
-    // Initialize philox random generator.
-    OP_REQUIRES_OK(context, generator_.Init(context));
+    OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
+    OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
+    if (deterministic_) {
+      // When deterministic_ is true and neither seed is set, pick random seeds.
+      if ((seed_ == 0) && (seed2_ == 0)) {
+        seed_ = random::New64();
+        seed2_ = random::New64();
+      }
+    } else {
+      OP_REQUIRES(
+          context, (seed_ == 0) && (seed2_ == 0),
+          errors::InvalidArgument(
+              "Both seed and seed2 should be 0 if deterministic is false."));
+    }
   }
 
   void Compute(OpKernelContext* context) override {
@@ -67,49 +79,37 @@ class FractionalMaxPoolOp : public OpKernel {
     OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims,
                 errors::InvalidArgument("tensor_in must be 4-dimensional"));
 
+    std::vector<int> input_size(tensor_in_and_out_dims);
+    std::vector<int> output_size(tensor_in_and_out_dims);
     for (int i = 0; i < tensor_in_and_out_dims; ++i) {
-      input_size_.push_back(tensor_in.dim_size(i));
+      input_size[i] = tensor_in.dim_size(i);
     }
     // Output size.
     for (int i = 0; i < tensor_in_and_out_dims; ++i) {
       // This must match the same logic in the shape function in
       // core/ops/nn_ops.cc.
-      output_size_.push_back(
-          static_cast<int>(floor(input_size_[i] / pooling_ratio_[i])));
-      DCHECK_GT(output_size_[i], 0);
+      output_size[i] =
+          static_cast<int>(floor(input_size[i] / pooling_ratio_[i]));
+      DCHECK_GT(output_size[i], 0);
     }
 
     // Generate pooling sequence.
     std::vector<int64> height_cum_seq;
     std::vector<int64> width_cum_seq;
-    if (deterministic_) {
-      if (pooling_region_generated_) {
-        height_cum_seq = height_cum_seq_;
-        width_cum_seq = width_cum_seq_;
-      } else {
-        height_cum_seq = GeneratePoolingSequence(
-            input_size_[1], output_size_[1], &generator_, pseudo_random_);
-        width_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
-                                                &generator_, pseudo_random_);
-        mutex_lock lock(mu_);
-        height_cum_seq_ = height_cum_seq;
-        width_cum_seq_ = width_cum_seq;
-        pooling_region_generated_ = true;
-      }
-    } else {
-      height_cum_seq = GeneratePoolingSequence(input_size_[1], output_size_[1],
-                                               &generator_, pseudo_random_);
-      width_cum_seq = GeneratePoolingSequence(input_size_[2], output_size_[2],
-                                              &generator_, pseudo_random_);
-    }
+    GuardedPhiloxRandom generator;
+    generator.Init(seed_, seed2_);
+    height_cum_seq = GeneratePoolingSequence(input_size[1], output_size[1],
+                                             &generator, pseudo_random_);
+    width_cum_seq = GeneratePoolingSequence(input_size[2], output_size[2],
+                                            &generator, pseudo_random_);
 
     // Prepare output.
     Tensor* output_tensor = nullptr;
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(
-                       0, TensorShape({output_size_[0], output_size_[1],
-                                       output_size_[2], output_size_[3]}),
-                       &output_tensor));
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                0,
+                                TensorShape({output_size[0], output_size[1],
+                                             output_size[2], output_size[3]}),
+                                &output_tensor));
     Tensor* output_height_seq_tensor = nullptr;
     OP_REQUIRES_OK(
         context,
@@ -122,12 +122,11 @@ class FractionalMaxPoolOp : public OpKernel {
                      2, TensorShape({static_cast<int64>(width_cum_seq.size())}),
                      &output_width_seq_tensor));
 
-    ConstEigenMatrixMap in_mat(
-        tensor_in.flat<T>().data(), input_size_[3],
-        input_size_[2] * input_size_[1] * input_size_[0]);
+    ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), input_size[3],
+                               input_size[2] * input_size[1] * input_size[0]);
 
-    EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size_[3],
-                           output_size_[2] * output_size_[1] * output_size_[0]);
+    EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size[3],
+                           output_size[2] * output_size[1] * output_size[0]);
 
     // Initializes the output tensor with MIN<T>.
     output_tensor->flat<T>().setConstant(Eigen::NumTraits<T>::lowest());
@@ -149,9 +148,9 @@ class FractionalMaxPoolOp : public OpKernel {
     // 1: height / row
     // 2: width / col
     // 3: depth / channel
-    const int64 height_max = input_size_[1] - 1;
-    const int64 width_max = input_size_[2] - 1;
-    for (int64 b = 0; b < input_size_[0]; ++b) {
+    const int64 height_max = input_size[1] - 1;
+    const int64 width_max = input_size[2] - 1;
+    for (int64 b = 0; b < input_size[0]; ++b) {
       // height sequence.
       for (int64 hs = 0; hs < height_cum_seq.size() - 1; ++hs) {
         // height start and end.
@@ -163,7 +162,7 @@ class FractionalMaxPoolOp : public OpKernel {
         // width sequence.
         for (int64 ws = 0; ws < width_cum_seq.size() - 1; ++ws) {
           const int64 out_offset =
-              (b * output_size_[1] + hs) * output_size_[2] + ws;
+              (b * output_size[1] + hs) * output_size[2] + ws;
           // width start and end.
           const int64 width_start = width_cum_seq[ws];
           int64 width_end =
@@ -172,7 +171,7 @@ class FractionalMaxPoolOp : public OpKernel {
           for (int64 h = height_start; h <= height_end; ++h) {
             for (int64 w = width_start; w <= width_end; ++w) {
               const int64 in_offset =
-                  (b * input_size_[1] + h) * input_size_[2] + w;
+                  (b * input_size[1] + h) * input_size[2] + w;
               out_mat.col(out_offset) =
                   out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset));
             }
@@ -184,18 +183,11 @@ class FractionalMaxPoolOp : public OpKernel {
 
  private:
   bool deterministic_;
-  // meaningful only when deterministic_ is true.
-  mutex mu_;
-  std::vector<int64> height_cum_seq_;
-  std::vector<int64> width_cum_seq_;
-  bool pooling_region_generated_;
-
-  std::vector<int32> input_size_;
-  std::vector<int32> output_size_;
+  int64 seed_;
+  int64 seed2_;
   std::vector<float> pooling_ratio_;
   bool pseudo_random_;
   bool overlapping_;
-  GuardedPhiloxRandom generator_;
 };
 
 #define REGISTER_FRACTIONALMAXPOOL(type)                                      \
@@ -243,15 +235,13 @@ class FractionalMaxPoolGradOp : public OpKernel {
 
     // Just to make it similar to FractionalMaxPoolOp.
     constexpr int tensor_in_and_out_dims = 4;
-    std::vector<int64> input_size;
-    std::vector<int64> output_size;
-    input_size.reserve(tensor_in_and_out_dims);
+    std::vector<int64> input_size(tensor_in_and_out_dims);
+    std::vector<int64> output_size(tensor_in_and_out_dims);
     for (int i = 0; i < tensor_in_and_out_dims; ++i) {
-      input_size.push_back(tensor_in.dim_size(i));
+      input_size[i] = tensor_in.dim_size(i);
     }
-    output_size.reserve(tensor_in_and_out_dims);
     for (int i = 0; i < tensor_in_and_out_dims; ++i) {
-      output_size.push_back(tensor_out.dim_size(i));
+      output_size[i] = tensor_out.dim_size(i);
     }
 
     // ---------
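
With the generator re-initialized from (seed, seed2) inside every Compute call, deterministic mode is now reproducible by construction rather than by cached state. A small sketch (TF 1.x API assumed):

    import numpy as np
    import tensorflow as tf  # assumes TensorFlow 1.x

    x = tf.placeholder(tf.float32, [None, None, None, 1])
    p, row_seq, col_seq = tf.nn.fractional_max_pool(
        x, [1, 1.5, 1.5, 1], deterministic=True, seed=7, seed2=13)

    with tf.Session() as sess:
        feed = {x: np.random.rand(2, 30, 30, 1)}
        r1, c1 = sess.run([row_seq, col_seq], feed)
        r2, c2 = sess.run([row_seq, col_seq], feed)
        # Philox is re-seeded per call, so the cumulative pooling sequences
        # match across runs with the same input shape.
        assert np.array_equal(r1, r2) and np.array_equal(c1, c2)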
index d7403fe6ee73eab9f681ef43b1c183e093729f6c..e5f65edd39b63668d2e7dcafe9929778906c55bb 100644 (file)
@@ -369,6 +369,7 @@ tf_py_test(
     srcs = ["fractional_avg_pool_op_test.py"],
     additional_deps = [
         "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:nn_grad",
@@ -383,6 +384,7 @@ tf_py_test(
     srcs = ["fractional_max_pool_op_test.py"],
     additional_deps = [
         "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:nn_grad",
index 48a51c8072416f3d494129f18912d67491fa5281..feec9934e459590bb1dd0bc5c7cf40013d3d8b88 100644 (file)
@@ -23,6 +23,8 @@ import math
 import numpy as np
 
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import nn_ops
@@ -310,6 +312,35 @@ class FractionalAvgTest(test.TestCase):
     self._ValidateFractionalAvgPoolResult(rand_mat, [1, 2, 2, 1], pseudo_random,
                                           overlapping)
 
+  def testDifferentInputTensorShape(self):
+    """Runs the operation in one session with different input tensor shapes."""
+    with self.test_session() as sess:
+      input_holder = array_ops.placeholder(dtypes.float32,
+                                           [None, None, None, 3])
+      pooling_ratio = [1, 1.5, 1.5, 1]
+      pseudo_random = False
+      overlapping = False
+      p, r, c = nn_ops.fractional_avg_pool(
+          input_holder,
+          pooling_ratio,
+          pseudo_random,
+          overlapping,
+          deterministic=True,
+          seed=self._SEED,
+          seed2=self._SEED2)
+      # First run.
+      input_a = np.zeros([3, 32, 32, 3])
+      actual, row_seq, col_seq = sess.run([p, r, c], {input_holder: input_a})
+      expected = self._GetExpectedFractionalAvgPoolResult(
+          input_a, row_seq, col_seq, overlapping)
+      self.assertSequenceEqual(expected.shape, actual.shape)
+      # Second run.
+      input_b = np.zeros([4, 60, 60, 3])
+      actual, row_seq, col_seq = sess.run([p, r, c], {input_holder: input_b})
+      expected = self._GetExpectedFractionalAvgPoolResult(
+          input_b, row_seq, col_seq, overlapping)
+      self.assertSequenceEqual(expected.shape, actual.shape)
+
 
 class FractionalAvgPoolGradTest(test.TestCase):
   """Tests for FractionalAvgPoolGrad.
index d380c31de35510c415420b3302fe1d4ff07877d2..5983ae7759dbf3eb2db9867def829ce8dbeb4b73 100644 (file)
@@ -23,6 +23,8 @@ import math
 import numpy as np
 
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import nn_ops
@@ -281,6 +283,35 @@ class FractionalMaxPoolTest(test.TestCase):
     self._ValidateFractionalMaxPoolResult(rand_mat, [1, 2, 2, 1], pseudo_random,
                                           overlapping)
 
+  def testDifferentInputTensorShape(self):
+    """Runs the operation in one session with different input tensor shapes."""
+    with self.test_session() as sess:
+      input_holder = array_ops.placeholder(dtypes.float32,
+                                           [None, None, None, 3])
+      pooling_ratio = [1, 1.5, 1.5, 1]
+      pseudo_random = False
+      overlapping = False
+      p, r, c = nn_ops.fractional_max_pool(
+          input_holder,
+          pooling_ratio,
+          pseudo_random,
+          overlapping,
+          deterministic=True,
+          seed=self._SEED,
+          seed2=self._SEED2)
+      # First run.
+      input_a = np.zeros([3, 32, 32, 3])
+      actual, row_seq, col_seq = sess.run([p, r, c], {input_holder: input_a})
+      expected = self._GetExpectedFractionalMaxPoolResult(
+          input_a, row_seq, col_seq, overlapping)
+      self.assertSequenceEqual(expected.shape, actual.shape)
+      # Second run.
+      input_b = np.zeros([4, 45, 45, 3])
+      actual, row_seq, col_seq = sess.run([p, r, c], {input_holder: input_b})
+      expected = self._GetExpectedFractionalMaxPoolResult(
+          input_b, row_seq, col_seq, overlapping)
+      self.assertSequenceEqual(expected.shape, actual.shape)
+
 
 class FractionalMaxPoolGradTest(test.TestCase):
   """Tests for FractionalMaxPoolGrad.