--- /dev/null
+import graphviz # type: ignore[import]
+
+def get_layer_name_type(layer):
+ return "\n".join(f"{i}" for i in [layer.name, layer.type])
+
+def trt_network_to_dot_graph(network):
+ dot = graphviz.Digraph(comment="Network")
+
+ # add nodes (layers)
+ for i in range(network.num_layers):
+ layer = network.get_layer(i)
+ dot.node(get_layer_name_type(layer))
+
+ # add nodes (inputs)
+ for i in range(network.num_inputs):
+ dot.node(network.get_input(i).name)
+
+ # add nodes (outputs)
+ for i in range(network.num_outputs):
+ dot.node(network.get_output(i).name)
+
+ # add layer->layer edges
+ for a in range(network.num_layers):
+ layer_a = network.get_layer(a)
+
+ for b in range(network.num_layers):
+ layer_b = network.get_layer(b)
+
+ for i in range(layer_a.num_outputs):
+ output_i = layer_a.get_output(i)
+
+ for j in range(layer_b.num_inputs):
+ input_j = layer_b.get_input(j)
+
+ if output_i == input_j:
+ dot.edge(get_layer_name_type(layer_a), get_layer_name_type(layer_b), label=str(input_j.shape))
+
+ # add input->layer edges
+ for i in range(network.num_inputs):
+ input_i = network.get_input(i)
+
+ for b in range(network.num_layers):
+ layer_b = network.get_layer(b)
+
+ for j in range(layer_b.num_inputs):
+ input_j = layer_b.get_input(j)
+
+ if input_i == input_j:
+ dot.edge(input_i.name, get_layer_name_type(layer_b), label=str(input_j.shape))
+
+ # add layer->output edges
+ for i in range(network.num_outputs):
+ input_i = network.get_output(i)
+
+ for b in range(network.num_layers):
+ layer_b = network.get_layer(b)
+
+ for j in range(layer_b.num_outputs):
+ input_j = layer_b.get_output(j)
+
+ if input_i == input_j:
+ dot.edge(get_layer_name_type(layer_b), input_i.name, label=str(input_j.shape))
+
+ return dot