Fix merge issues
authorSami Kama <sami.kama.git@gmail.com>
Mon, 12 Feb 2018 21:37:42 +0000 (13:37 -0800)
committerSami Kama <sami.kama.git@gmail.com>
Mon, 12 Feb 2018 21:37:42 +0000 (13:37 -0800)
1  2 
tensorflow/contrib/tensorrt/test/test_tftrt.py

@@@ -60,14 -60,12 +60,14 @@@ def run_graph(gdef, dumm_inp)
  
  
  if "__main__" in __name__:
-   inpDims = (100, 24, 24, 2)
-   dummy_input = np.random.random_sample(inpDims)
-   gdef = getSimpleGraphDef()
-   trt_graph = trt.create_inference_graph(gdef, ["output"],
-                                          inpDims[0])  # Get optimized graph
-   o1 = runGraph(gdef, dummy_input)
-   o2 = runGraph(trt_graph, dummy_input)
-   o3 = runGraph(trt_graph, dummy_input)
-   assert (np.array_equal(o1, o2))
-   assert (np.array_equal(o2, o3))
+   inp_dims = (100, 24, 24, 2)
+   dummy_input = np.random.random_sample(inp_dims)
+   gdef = get_simple_graph_def()
+   # Get optimized graph
+   trt_graph = trt.create_inference_graph(gdef, ["output"], inp_dims[0])
+   o1 = run_graph(gdef, dummy_input)
+   o2 = run_graph(trt_graph, dummy_input)
++  o3 = run_graph(trt_graph, dummy_input)
+   assert np.array_equal(o1, o2)
++  assert np.array_equal(o3, o2)  # sanity check
    print("Pass")