if "__main__" in __name__:
- inpDims = (100, 24, 24, 2)
- dummy_input = np.random.random_sample(inpDims)
- gdef = getSimpleGraphDef()
- trt_graph = trt.create_inference_graph(gdef, ["output"],
- inpDims[0]) # Get optimized graph
- o1 = runGraph(gdef, dummy_input)
- o2 = runGraph(trt_graph, dummy_input)
- o3 = runGraph(trt_graph, dummy_input)
- assert (np.array_equal(o1, o2))
- assert (np.array_equal(o2, o3))
+ inp_dims = (100, 24, 24, 2)
+ dummy_input = np.random.random_sample(inp_dims)
+ gdef = get_simple_graph_def()
+ # Get optimized graph
+ trt_graph = trt.create_inference_graph(gdef, ["output"], inp_dims[0])
+ o1 = run_graph(gdef, dummy_input)
+ o2 = run_graph(trt_graph, dummy_input)
++ o3 = run_graph(trt_graph, dummy_input)
+ assert np.array_equal(o1, o2)
++ assert np.array_equal(o3, o2) # sanity check
print("Pass")