2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from mo.front.common.partial_infer.utils import tf_window_op_pad_infer
20 from mo.front.extractor import attr_getter
21 # from mo.front.common.partial_infer.pooling import pool_explicit_padding_infer
22 from mo.front.extractor import spatial_getter
23 from mo.front.onnx.extractors.utils import get_backend_pad
24 from mo.graph.graph import Node, Graph
25 from mo.ops.op import Op, PermuteAttrs
# Constructor for the pooling op.
# Calls the base Op constructor with `graph` and a dict of default attributes;
# the visible entry registers this class's `infer` as the 'infer' callback.
# NOTE(review): this listing skips original lines 33-35 and 37-40, so the other
# default-attribute entries and how `attrs` is forwarded are not shown here.
31 def __init__(self, graph: Graph, attrs: dict):
32 super().__init__(graph, {
36 'infer': __class__.infer,
# Attribute spec list for backend serialization.
# Each entry pairs an output attribute name with its source. The source is either:
#   - a callable taking the node, or
#   - the name of a node attribute to copy.
# Presumably this is consumed by the op serializer; confirm against the Op base class.
# NOTE(review): original lines 42, 45, 48 and 51-55 are not shown. The list
# opening (likely `return [`), any further entries and the closing bracket are
# therefore missing from this listing.
41 def backend_attrs(self):
# Spatial strides / window sizes, serialized as comma-separated integers.
43 ('strides', lambda node: ','.join(map(str, node['stride'][node.spatial_dims]))),
44 ('kernel', lambda node: ','.join(map(str, node['window'][node.spatial_dims]))),
# Begin (index 0) and end (index 1) padding over the spatial axes,
# extracted from the per-axis pad pairs by get_backend_pad.
46 ('pads_begin', lambda node: ','.join(map(str, get_backend_pad(node.pad, node.spatial_dims, 0)))),
47 ('pads_end', lambda node: ','.join(map(str, get_backend_pad(node.pad, node.spatial_dims, 1)))),
# Hyphenated output names copied from the underscored node attributes.
49 ('pool-method', 'pool_method'),
50 ('exclude-pad', 'exclude_pad'),
# Alternate attribute spec list.
# Spatial values are emitted as separate -x / -y attributes rather than the
# comma-joined lists used by backend_attrs.
# NOTE(review): original lines 57, 59, 66 and 69-74 are not shown. The list
# opening, any further entries and the closing bracket are missing from this listing.
56 def backend_attrs_v2(self):
# Whole 'stride' attribute, formatted by attr_getter.
58 ('stride', lambda node: attr_getter(node, 'stride')),
# Per-axis values from the stride / window / pad attributes.
# The integer argument appears to index into the spatial dims (1 -> x, 0 -> y);
# confirm against spatial_getter.
60 spatial_getter('stride-x', 'stride', 1),
61 spatial_getter('stride-y', 'stride', 0),
62 spatial_getter('kernel-x', 'window', 1),
63 spatial_getter('kernel-y', 'window', 0),
# Pads are stored as a pair per axis. `x[0]` takes the first element of the
# pair, the same index backend_attrs passes to get_backend_pad for pads_begin.
64 spatial_getter('pad-x', 'pad', 1, lambda x: x[0]),
65 spatial_getter('pad-y', 'pad', 0, lambda x: x[0]),
67 ('pool-method', 'pool_method'),
68 ('exclude-pad', 'exclude_pad'),
# Shape inference for the pooling node.
# Takes only `node`; any decorator line above this def is not shown in this listing.
#
# What it does:
#   - reads the shape of input 0;
#   - fills in defaults for 'spatial_dims', 'pad', 'pad_spatial_shape' and
#     'stride' when they are not valid on the node;
#   - computes the spatial output size and writes node.out_node().shape;
#   - registers pad / stride / window / spatial_dims with PermuteAttrs.
#
# NOTE(review): this is a lossy listing. The embedded original line numbers skip
# 79-80, 109-111, 113-114, 116 and 120-121 (among others), so the following are
# not visible here:
#   - the body of the `input_shape is None` check;
#   - the branch structure around auto_pad;
#   - the binding of `rounding`;
#   - the tail of the output-size expression.
75 def infer(node: Node):
76 assert (len(node.in_nodes()) == 1)
77 input_shape = node.in_node(0).shape
78 if input_shape is None:
# The statement(s) guarded by the check above (original lines 79-80) are not
# shown. Presumably an early return while the input shape is still unknown.
# Default spatial axes: every axis except the first batch axis and the first channel axis.
81 if not node.has_valid('spatial_dims'):
82 node['spatial_dims'] = np.delete([x for x in range(len(input_shape))],
83 [node.batch_dims[0], node.channel_dims[0]])
85 input_spatial_shape = input_shape[node.spatial_dims]
87 # Setting default pad and stride attrs in case of None specified
# 'pad' holds one zero-filled 2-element pair per axis;
# 'pad_spatial_shape' is its rows for the spatial axes.
88 if not node.has_valid('pad'):
89 node['pad'] = np.array([[0, 0] for x in range(len(input_shape))], dtype=np.int64)
90 if not node.has_valid('pad_spatial_shape'):
91 node['pad_spatial_shape'] = node.pad[node.spatial_dims]
92 if not node.has_valid('stride'):
93 node['stride'] = np.array([1 for x in range(len(input_shape))], dtype=np.int64)
# Global pooling: the window spans the full spatial extent of the input.
# Non-spatial window entries stay 0.
95 if node.has_and_set('global_pool'):
96 node['window'] = np.zeros(len(input_shape), dtype=np.int64)
97 node.window[node.spatial_dims] = input_spatial_shape
99 window_spatial_shape = node.window[node.spatial_dims]
100 stride_spatial = node.stride[node.spatial_dims]
# NOTE(review): any() only fails when EVERY spatial stride is 0. A single zero
# stride passes and reaches the division by stride_spatial in the output-size
# computation below. all(stride_spatial) would match the message.
101 assert any(stride_spatial), 'Stride can not be zero in node {}'.format(node.id)
# auto_pad set: tf_window_op_pad_infer supplies both the spatial pads and the
# output spatial shape. The spatial pad rows are then expanded into a full
# per-axis pad array. What is done with `pad` afterwards (original lines 109-111)
# is not shown; presumably it is stored back on the node.
103 if node.has_valid('auto_pad'):
104 node.pad_spatial_shape, node.output_spatial_shape = tf_window_op_pad_infer(input_spatial_shape,
105 window_spatial_shape,
106 stride_spatial, node.auto_pad)
107 pad = np.zeros((len(input_shape), 2), dtype=np.int64)
108 pad[node.spatial_dims] = node.pad_spatial_shape
# Total padding per spatial axis: the two elements of each pad pair summed (axis=1).
112 pad_spatial_shape = np.add.reduce(node.pad_spatial_shape, axis=1)
# The body of this check (original line 116) and the default binding of
# `rounding` are not shown. Presumably 'full' selects a ceiling instead of a floor.
115 if node.has_valid('pooling_convention') and node.pooling_convention == 'full':
# Output size per spatial axis: rounding((input + total_pad - window) / stride).
# The remaining arguments of this expression (original lines 120-121) are not shown.
# NOTE(review): `np.float` was deprecated in NumPy 1.20 and removed in 1.24.
# `np.float64` (or builtin `float`) is the drop-in replacement.
117 output_spatial_shape = np.array(rounding(
118 np.array(input_spatial_shape + pad_spatial_shape - window_spatial_shape,
119 dtype=np.float) / stride_spatial),
# Second element of each spatial pad pair (the index backend_attrs uses for pads_end).
122 original_pads = np.array([i[1] for i in node.pad_spatial_shape])
# On axes with a non-zero end pad, drop one output position when
# (out - 1) * stride >= input + end_pad, i.e. when the last window would start
# at or beyond that bound.
124 for i in range(len(input_spatial_shape)):
125 if original_pads[i] and (output_spatial_shape[i] - 1) * stride_spatial[i] >= \
126 input_spatial_shape[i] + original_pads[i]:
127 output_spatial_shape[i] -= 1
129 node['output_spatial_shape'] = output_spatial_shape
# Output shape = input shape with the spatial axes replaced by the computed sizes.
131 output_shape = input_shape.copy()
132 output_shape[node.spatial_dims] = node.output_spatial_shape
133 node.out_node().shape = output_shape
# Tie the per-axis attributes to input port 0 ('input:0'). Presumably this is so
# they are permuted along with that input's layout; confirm against PermuteAttrs.
136 PermuteAttrs.create_permute_attrs(node, attrs=[('pad', 'input:0'),
137 ('stride', 'input:0'),
138 ('window', 'input:0'),
139 ('spatial_dims', 'input:0')])