Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / assign_elimination.py
index 2a6dc07..6550c27 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ import logging as log
 import networkx as nx
 
 from mo.front.common.replacement import FrontReplacementOp
+from mo.graph.graph import Node, Graph
 from mo.utils.error import Error
 
 
@@ -26,7 +27,7 @@ class AssignElimination(FrontReplacementOp):
     op = "Assign"
     enabled = True
 
-    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_sub_graph(self, graph: Graph, match: dict):
         node = match['op']
         # here we request all data flow output edges (control flow edges will not be listed)
         out_edges = node.out_edges()
@@ -41,7 +42,7 @@ class AssignSubElimination(FrontReplacementOp):
     op = "AssignSub"
     enabled = True
 
-    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_sub_graph(self, graph: Graph, match: dict):
         node = match['op']
         # here we request all data flow output edges (control flow edges will not be listed)
         out_edges = node.out_edges()
@@ -56,7 +57,7 @@ class AssignAddElimination(FrontReplacementOp):
     op = "AssignAdd"
     enabled = True
 
-    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_sub_graph(self, graph: Graph, match: dict):
         node = match['op']
         # here we request all data flow output edges (control flow edges will not be listed)
         out_edges = node.out_edges()
@@ -65,3 +66,18 @@ class AssignAddElimination(FrontReplacementOp):
             log.debug('AssignAdd op was removed {}'.format(node.id))
         else:
             raise Error('Data flow edge coming out of AssignAdd node {}'.format(node.id))
+
+
+class AssertElimination(FrontReplacementOp):
+    op = "Assert"
+    enabled = True
+
+    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
+        node = match['op']
+        # here we request all data flow output edges (control flow edges will not be listed)
+        out_edges = node.out_edges()
+        if len(out_edges) == 0:
+            graph.remove_node(node.id)
+            log.debug('Assert op was removed {}'.format(node.id))
+        else:
+            raise Error('Data flow edge coming out of Assert node {}'.format(node.id))