"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
-import networkx as nx
-
-from extensions.middle.NormalizeFullyConnected import NormalizeFullyConnected
from mo.front.common.partial_infer.utils import mark_input_bins, assign_dims_to_weights, int64_array
+from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.op import PermuteAttrs
class GemmResolver(MiddleReplacementPattern):
enabled = True
+ graph_condition = [lambda graph: graph.graph['fw'] != 'tf']
def run_before(self):
+ from extensions.middle.NormalizeFullyConnected import NormalizeFullyConnected
return [NormalizeFullyConnected]
+ def run_after(self):
+ from extensions.middle.pass_separator import MiddleStart
+ return [MiddleStart]
+
def pattern(self):
return dict(
nodes=[
- ('input_0', dict(kind='data')),
- ('input_1', dict(kind='data')),
- ('fc', dict(op='MatMul')),
- ('fc_data', dict(kind='data'))],
+ ('input_0', dict(kind='data')),
+ ('input_1', dict(kind='data')),
+ ('fc', dict(op='MatMul')),
+ ('fc_data', dict(kind='data'))],
edges=[
('input_0', 'fc', {'in': 0}),
('input_1', 'fc', {'in': 1}),
]
)
- def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(self, graph: Graph, match: dict):
if not match['input_0'].has_valid('value') and not match['input_1'].has_valid('value') or \
- not match['input_0'].has_valid('value') and match['input_1'].has_valid('value') and match['input_1'].shape.size > 2:
+ not match['input_0'].has_valid('value') and match['input_1'].has_valid('value') and match[
+ 'input_1'].shape.size > 2:
match['fc']['type'] = 'GEMM'
elif not match['input_0'].has_valid('value') and match['input_1'].has_valid('value'):
match['fc']['type'] = 'FullyConnected'
weights_shape = weights_node.shape
node['out-size'] = weights_shape[1]
-
-
-