"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import logging as log
-import networkx as nx
import numpy as np
from mo.front.common.replacement import FrontReplacementSubgraph
-from mo.graph.graph import erase_node
+from mo.graph.graph import Graph
from mo.middle.passes.convert_data_type import SUPPORTED_DATA_TYPES
from mo.ops.const import Const
from mo.utils.error import Error
class FreezePlaceholderValue(FrontReplacementSubgraph):
"""
- Replaces existing placeholder to Constant node with provided value. It takes value from raplacement_dict as string
- and casts it to actual node data type
- :param replacement_dict: dictionary with node names as keys and strings as values
+ Replaces existing placeholder to Constant node with provided value. It takes value from freeze_placeholder as
+ a string and casts it to actual node data type
"""
+ enabled = True
+ graph_condition = [lambda graph: graph.graph['freeze_placeholder'] is not None]
- enabled = False
- replacement_dict = dict()
+ def run_after(self):
+ from extensions.front.restore_ports import RestorePorts
+ return [RestorePorts]
+
+ def run_before(self):
+ from extensions.front.pass_separator import FrontStart
+ return [FrontStart]
@staticmethod
def pattern():
edges=[]
)
- def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
+ def replace_sub_graph(self, graph: Graph, match: dict):
ph = match['placeholder']
- if ph.name in self.replacement_dict:
+ if ph.name in graph.graph['freeze_placeholder']:
name = ph.name
if ph.has_and_set('data_type'):
data_type = ph.data_type
else:
data_type = SUPPORTED_DATA_TYPES[graph.graph['cmd_params'].data_type][0]
- string_value = self.replacement_dict[name]
+ string_value = graph.graph['freeze_placeholder'][name]
try:
if data_type != np.bool:
value = np.array(string_value, dtype=data_type)
new_node = Const(graph).create_node(
attrs={'value': value, 'data_type': type(value), 'name': name + '/const_placeholder',
'shape': ph.shape})
- erase_node(ph)
+ graph.erase_node(ph)
graph.add_edges_from([(new_node.id, v, attrs) for u, v, attrs in out_edges])
log.info("Placeholder node \"{}\" was replaced with Const node \"{}\" with value \"{}\"".format(
name, new_node.name, value))