Improve shape inference for tf.contrib.signal.frame.
authorRJ Ryan <rjryan@google.com>
Tue, 1 May 2018 19:02:59 +0000 (12:02 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 1 May 2018 19:05:51 +0000 (12:05 -0700)
PiperOrigin-RevId: 194972934

tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
tensorflow/contrib/signal/python/ops/shape_ops.py

index 64cc8c7..f132050 100644 (file)
@@ -119,7 +119,7 @@ class FrameTest(test.TestCase):
     frame_step = 1
     result = shape_ops.frame(signal, frame_length, frame_step,
                              pad_end=True, pad_value=99, axis=1)
-    self.assertEqual([1, None, None, 3, 4], result.shape.as_list())
+    self.assertEqual([1, 2, None, 3, 4], result.shape.as_list())
 
     result = shape_ops.frame(signal, frame_length, frame_step,
                              pad_end=False, axis=1)
index 1ddc294..91862f0 100644 (file)
@@ -43,13 +43,13 @@ def _infer_frame_shape(signal, frame_length, frame_step, pad_end, axis):
   outer_dimensions = signal_shape[:axis]
   inner_dimensions = signal_shape[axis:][1:]
   if signal_shape and frame_axis is not None:
-    if frame_step and frame_length is not None:
-      if pad_end:
-        # Double negative is so that we round up.
-        num_frames = -(-frame_axis // frame_step)
-      else:
-        num_frames = (frame_axis - frame_length + frame_step) // frame_step
-      num_frames = max(0, num_frames)
+    if frame_step is not None and pad_end:
+      # Double negative is so that we round up.
+      num_frames = max(0, -(-frame_axis // frame_step))
+    elif frame_step is not None and frame_length is not None:
+      assert not pad_end
+      num_frames = max(
+          0, (frame_axis - frame_length + frame_step) // frame_step)
   return outer_dimensions + [num_frames, frame_length] + inner_dimensions