Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / freeze_placeholder_value.py
index 2775738..cda5a95 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 
 import logging as log
 
-import networkx as nx
 import numpy as np
 
 from mo.front.common.replacement import FrontReplacementSubgraph
-from mo.graph.graph import erase_node
+from mo.graph.graph import Graph
 from mo.middle.passes.convert_data_type import SUPPORTED_DATA_TYPES
 from mo.ops.const import Const
 from mo.utils.error import Error
@@ -28,13 +27,19 @@ from mo.utils.error import Error
 
 class FreezePlaceholderValue(FrontReplacementSubgraph):
     """
-    Replaces existing placeholder to Constant node with provided value. It takes value from raplacement_dict as string
-    and casts it to actual node data type
-    :param replacement_dict: dictionary with node names as keys and strings as values
+    Replaces existing placeholder to Constant node with provided value. It takes value from freeze_placeholder as
+    a string and casts it to actual node data type
     """
+    enabled = True
+    graph_condition = [lambda graph: graph.graph['freeze_placeholder'] is not None]
 
-    enabled = False
-    replacement_dict = dict()
+    def run_after(self):
+        from extensions.front.restore_ports import RestorePorts
+        return [RestorePorts]
+
+    def run_before(self):
+        from extensions.front.pass_separator import FrontStart
+        return [FrontStart]
 
     @staticmethod
     def pattern():
@@ -43,15 +48,15 @@ class FreezePlaceholderValue(FrontReplacementSubgraph):
             edges=[]
         )
 
-    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_sub_graph(self, graph: Graph, match: dict):
         ph = match['placeholder']
-        if ph.name in self.replacement_dict:
+        if ph.name in graph.graph['freeze_placeholder']:
             name = ph.name
             if ph.has_and_set('data_type'):
                 data_type = ph.data_type
             else:
                 data_type = SUPPORTED_DATA_TYPES[graph.graph['cmd_params'].data_type][0]
-            string_value = self.replacement_dict[name]
+            string_value = graph.graph['freeze_placeholder'][name]
             try:
                 if data_type != np.bool:
                     value = np.array(string_value, dtype=data_type)
@@ -76,7 +81,7 @@ class FreezePlaceholderValue(FrontReplacementSubgraph):
             new_node = Const(graph).create_node(
                 attrs={'value': value, 'data_type': type(value), 'name': name + '/const_placeholder',
                        'shape': ph.shape})
-            erase_node(ph)
+            graph.erase_node(ph)
             graph.add_edges_from([(new_node.id, v, attrs) for u, v, attrs in out_edges])
             log.info("Placeholder node \"{}\" was replaced with Const node \"{}\" with value \"{}\"".format(
                 name, new_node.name, value))