Update tf.nn.[max,avg]_pool to specify that it accepts list/tuple stride and kernel...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 9 May 2018 20:28:00 +0000 (13:28 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 9 May 2018 20:49:41 +0000 (13:49 -0700)
If you actually specify a tensor argument here, you get the error:
TypeError: Expected list for 'ksize' argument to 'avg_pool' Op, not <tf.Tensor 'Const_1:0' shape=(4,) dtype=int32>.
PiperOrigin-RevId: 196019507

tensorflow/python/ops/nn_ops.py

index cd07550..09a4425 100644 (file)
@@ -2100,11 +2100,10 @@ def avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
   Args:
     value: A 4-D `Tensor` of shape `[batch, height, width, channels]` and type
       `float32`, `float64`, `qint8`, `quint8`, or `qint32`.
-    ksize: A 1-D int Tensor of 4 elements.
-      The size of the window for each dimension of the input tensor.
-    strides: A 1-D int Tensor of 4 elements
-      The stride of the sliding window for each dimension of the
-      input tensor.
+    ksize: A list or tuple of 4 ints. The size of the window for each dimension
+      of the input tensor.
+    strides: A list or tuple of 4 ints. The stride of the sliding window for
+      each dimension of the input tensor.
     padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
       See the @{tf.nn.convolution$comment here}
     data_format: A string. 'NHWC' and 'NCHW' are supported.
@@ -2130,10 +2129,10 @@ def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
 
   Args:
     value: A 4-D `Tensor` of the format specified by `data_format`.
-    ksize: A 1-D int Tensor of 4 elements.  The size of the window for
+    ksize: A list or tuple of 4 ints. The size of the window for each dimension
+      of the input tensor.
+    strides: A list or tuple of 4 ints. The stride of the sliding window for
       each dimension of the input tensor.
-    strides: A 1-D int Tensor of 4 elements.  The stride of the sliding
-      window for each dimension of the input tensor.
     padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
       See the @{tf.nn.convolution$comment here}
     data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.