"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import logging as log
-import networkx as nx
import numpy as np
from extensions.middle.ConvertGroupedStridedSlice import ConvertGroupedStridedSlice
from extensions.middle.SliceConverter import ConvertSlice
+from mo.graph.graph import Graph
from mo.middle.passes.eliminate import remove_op_node_with_data_node
from mo.middle.replacement import MiddleReplacementPattern
edges=[]
)
- def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(self, graph: Graph, match: dict):
output_data_node = match['strided_slice'].out_node(0)
input_data_node = match['strided_slice'].in_node(0)
if np.array_equal(input_data_node.shape, output_data_node.shape) and \
# remove inputs to Strided Slice so it has just one input with data so we can use 'remove_op_node' function
graph.remove_edge(match['strided_slice'].in_node(1).id, match['strided_slice'].id)
graph.remove_edge(match['strided_slice'].in_node(2).id, match['strided_slice'].id)
- graph.remove_edge(match['strided_slice'].in_node(3).id, match['strided_slice'].id)
+ if len(match['strided_slice'].in_nodes()) > 3:
+ graph.remove_edge(match['strided_slice'].in_node(3).id, match['strided_slice'].id)
remove_op_node_with_data_node(graph, match['strided_slice'])