Fix a tensorflow test bug. (#3165)
authorlixiaoquan <radioheads@163.com>
Fri, 10 May 2019 17:14:39 +0000 (01:14 +0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 10 May 2019 17:14:38 +0000 (10:14 -0700)
Length of input_shape isn't always 4.

tests/python/frontend/tensorflow/test_forward.py

index 8dd538a..1579769 100644 (file)
@@ -185,7 +185,7 @@ def _test_pooling_iteration(input_shape, **kwargs):
 def _test_pooling(input_shape, **kwargs):
     _test_pooling_iteration(input_shape, **kwargs)
 
-    if is_gpu_available():
+    if is_gpu_available() and (len(input_shape) == 4):
         input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
         kwargs['data_format'] = 'NCHW'
         _test_pooling_iteration(input_shape, **kwargs)