"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
"""
import math
-import networkx as nx
import numpy as np
from mo.front.common.partial_infer.elemental import single_output_infer
from mo.front.common.partial_infer.reshape import tf_reshape_shape_infer
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
from mo.ops.op import Op
from mo.utils.error import Error
op = 'Reshape'
enabled = True
- def __init__(self, graph: nx.MultiDiGraph, attrs: dict):
+ def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'kind': 'op',
'type': __class__.op,
'op': __class__.op,
+ 'in_ports_count': 2,
+ 'out_ports_count': 1,
'infer': lambda node: single_output_infer(node, tf_reshape_shape_infer,
lambda node: np.reshape(node.in_node().value,
node.out_node().shape))
}, attrs)
- def supported_attrs(self):
- return [('dim', lambda node: ','.join(map(str, node['dim'])))]
-
@staticmethod
def kaldi_infer(node: Node):
in_node = node.in_node().in_node() # prev_layer_node -> data -> this_node
# Convolution/Pooling layers. Therefore there are 4 cases with different
# partial inference.
batch = input_shape[0]
- if in_node.op == 'Convolution' or in_node.op == 'Pooling':
+ if in_node.op in ['Convolution', 'Pooling', 'Permute']:
output_spatial = np.array([batch, np.prod(input_shape[1:])], dtype=np.int64)
return Reshape.set_shape_and_dim(node, output_spatial)
# Supports ONLY NCHW and NH layouts