From: Sami Kama Date: Mon, 12 Feb 2018 21:37:42 +0000 (-0800) Subject: Fix merge issues X-Git-Tag: upstream/v1.7.0~202^2~6 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=14948a8aeb23924687bbe02ec6e586b37acb102f;p=platform%2Fupstream%2Ftensorflow.git Fix merge issues --- 14948a8aeb23924687bbe02ec6e586b37acb102f diff --cc tensorflow/contrib/tensorrt/test/test_tftrt.py index adf3438,7195666..69fccd3 --- a/tensorflow/contrib/tensorrt/test/test_tftrt.py +++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py @@@ -60,14 -60,12 +60,14 @@@ def run_graph(gdef, dumm_inp) if "__main__" in __name__: - inpDims = (100, 24, 24, 2) - dummy_input = np.random.random_sample(inpDims) - gdef = getSimpleGraphDef() - trt_graph = trt.create_inference_graph(gdef, ["output"], - inpDims[0]) # Get optimized graph - o1 = runGraph(gdef, dummy_input) - o2 = runGraph(trt_graph, dummy_input) - o3 = runGraph(trt_graph, dummy_input) - assert (np.array_equal(o1, o2)) - assert (np.array_equal(o2, o3)) + inp_dims = (100, 24, 24, 2) + dummy_input = np.random.random_sample(inp_dims) + gdef = get_simple_graph_def() + # Get optimized graph + trt_graph = trt.create_inference_graph(gdef, ["output"], inp_dims[0]) + o1 = run_graph(gdef, dummy_input) + o2 = run_graph(trt_graph, dummy_input) ++ o3 = run_graph(trt_graph, dummy_input) + assert np.array_equal(o1, o2) ++ assert np.array_equal(o3, o2) # sanity check print("Pass")