Fix linter errors in test_tftrt.py
authorgracehoney <31743510+aaroey@users.noreply.github.com>
Mon, 12 Feb 2018 17:41:43 +0000 (09:41 -0800)
committergracehoney <31743510+aaroey@users.noreply.github.com>
Mon, 12 Feb 2018 17:41:43 +0000 (09:41 -0800)
tensorflow/contrib/tensorrt/test/test_tftrt.py

index 927a3e4..7195666 100644 (file)
@@ -18,33 +18,33 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
 import tensorflow as tf
 import tensorflow.contrib.tensorrt as trt
-import numpy as np
 
 
-def getSimpleGraphDef():
+def get_simple_graph_def():
   """Create a simple graph and return its graph_def"""
   g = tf.Graph()
   with g.as_default():
-    A = tf.placeholder(dtype=tf.float32, shape=(None, 24, 24, 2), name="input")
+    a = tf.placeholder(dtype=tf.float32, shape=(None, 24, 24, 2), name="input")
     e = tf.constant(
         [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
         name="weights",
         dtype=tf.float32)
     conv = tf.nn.conv2d(
-        input=A, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv")
+        input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv")
     b = tf.constant([4., 1.5, 2., 3., 5., 7.], name="bias", dtype=tf.float32)
     t = tf.nn.bias_add(conv, b, name="biasAdd")
     relu = tf.nn.relu(t, "relu")
     idty = tf.identity(relu, "ID")
     v = tf.nn.max_pool(
         idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
-    out = tf.squeeze(v, name="output")
+    tf.squeeze(v, name="output")
   return g.as_graph_def()
 
 
-def runGraph(gdef, dumm_inp):
+def run_graph(gdef, dumm_inp):
   gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.50)
   tf.reset_default_graph()
   g = tf.Graph()
@@ -60,12 +60,12 @@ def runGraph(gdef, dumm_inp):
 
 
 if "__main__" in __name__:
-  inpDims = (100, 24, 24, 2)
-  dummy_input = np.random.random_sample(inpDims)
-  gdef = getSimpleGraphDef()
-  trt_graph = trt.create_inference_graph(gdef, ["output"],
-                                         inpDims[0])  # Get optimized graph
-  o1 = runGraph(gdef, dummy_input)
-  o2 = runGraph(trt_graph, dummy_input)
-  assert (np.array_equal(o1, o2))
+  inp_dims = (100, 24, 24, 2)
+  dummy_input = np.random.random_sample(inp_dims)
+  gdef = get_simple_graph_def()
+  # Get optimized graph
+  trt_graph = trt.create_inference_graph(gdef, ["output"], inp_dims[0])
+  o1 = run_graph(gdef, dummy_input)
+  o2 = run_graph(trt_graph, dummy_input)
+  assert np.array_equal(o1, o2)
   print("Pass")