"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import logging as log
-import networkx as nx
import numpy as np
from mo.front.common.layout import get_features_dim, get_height_dim, get_width_dim
from mo.front.common.partial_infer.utils import int64_array
+from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.reshape import Reshape
enabled = True
+ def run_after(self):
+ from extensions.middle.pass_separator import MiddleStart
+ return [MiddleStart]
+
def pattern(self):
return dict(
nodes=[
]
)
- def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(self, graph: Graph, match: dict):
reshape1 = match['reshape1']
reshape2 = match['reshape2']
transpose = match['transpose']
new_transpose_shape = np.array(new_reshape1_shape[new_transpose_order])
reshape1.out_node().shape = new_reshape1_shape
+ reshape1.dim = np.copy(new_reshape1_shape)
+
transpose.order = new_transpose_order
transpose.out_node().shape = new_transpose_shape
enabled = True
+ def run_before(self):
+ from extensions.middle.pass_separator import MiddleFinish
+ return [MiddleFinish]
+
def pattern(self):
return dict(
nodes=[
('softmax', 'softmax_data'),
])
- def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(self, graph: Graph, match: dict):
layout = graph.graph['layout']
if layout != 'NHWC':
return