2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from mo.utils.error import Error
# Permutations that convert a 4D shape between the two supported layouts.
nchw_to_nhwc_permute = np.array([0, 2, 3, 1], dtype=np.int64)
nhwc_to_nchw_permute = np.array([0, 3, 1, 2], dtype=np.int64)
supported_layouts = ('NCHW', 'NHWC')
# The attribute 'layout' in the graph.graph can have two values only: "NCHW" or "NHWC". If the tensor has 5 dimensions
# then it is necessary to transform "NCHW" to "NCDHW" and "NHWC" to "NDHWC" respectively. The dictionary below is used
# for this transformation: the first key is the shape length and the second key is the graph layout.
indices_mapping = {4: {'NCHW': 'NCHW',
                       'NHWC': 'NHWC'},
                   5: {'NCHW': 'NCDHW',
                       'NHWC': 'NDHWC'}}
def convert_shape(shape: np.array, permute: np.array):
    """
    Reorders the dimensions of the shape according to the permutation.
    :param shape: the shape to be permuted.
    :param permute: the permutation: element 'i' of the result is taken from position permute[i] of the 'shape'.
    :return: np.array with the permuted shape. It has as many elements as the 'permute'.
    """
    # the result is built directly from the permutation, so no pre-allocated buffer of a fixed rank is needed
    return np.array([shape[perm_ind] for perm_ind in permute])
def get_depth_dim(layout: str, shape_len: int):
    """
    Gets index of the dimension corresponding to depth.
    :param layout: string representing layout: NCHW or NHWC usually.
    :param shape_len: the shape length. Must be 5 because only 5D tensors have the depth dimension.
    :return: index of the 'D' character
    """
    assert layout in supported_layouts
    # a layout string without the 'D' character would make str.find return -1, which silently addresses the last
    # dimension when used as an index, so such requests are rejected explicitly
    assert shape_len == 5
    return indices_mapping[shape_len][layout].find('D')
def get_height_dim(layout: str, shape_len: int):
    """
    Finds the position of the height dimension in a tensor of the given rank.
    :param layout: the graph layout string; must be one of the supported layouts.
    :param shape_len: number of dimensions of the tensor; 4 or 5.
    :return: position of the 'H' character in the layout string selected for this rank
    """
    assert layout in supported_layouts
    assert 4 <= shape_len <= 5
    layout_for_rank = indices_mapping[shape_len][layout]
    return layout_for_rank.find('H')
def get_width_dim(layout: str, shape_len: int):
    """
    Finds the position of the width dimension in a tensor of the given rank.
    :param layout: the graph layout string; must be one of the supported layouts.
    :param shape_len: number of dimensions of the tensor; 4 or 5.
    :return: position of the 'W' character in the layout string selected for this rank
    """
    assert layout in supported_layouts
    assert 4 <= shape_len <= 5
    layout_for_rank = indices_mapping[shape_len][layout]
    return layout_for_rank.find('W')
def get_features_dim(layout: str, shape_len: int):
    """
    Finds the position of the features (channels) dimension in a tensor of the given rank.
    :param layout: the graph layout string; must be one of the supported layouts.
    :param shape_len: number of dimensions of the tensor; 4 or 5.
    :return: position of the 'C' character in the layout string selected for this rank
    """
    assert layout in supported_layouts
    assert 4 <= shape_len <= 5
    layout_for_rank = indices_mapping[shape_len][layout]
    return layout_for_rank.find('C')
def get_batch_dim(layout: str, shape_len: int):
    """
    Finds the position of the batch dimension in a tensor of the given rank.
    :param layout: the graph layout string; must be one of the supported layouts.
    :param shape_len: number of dimensions of the tensor; 4 or 5.
    :return: position of the 'N' character in the layout string selected for this rank
    """
    assert layout in supported_layouts
    assert 4 <= shape_len <= 5
    layout_for_rank = indices_mapping[shape_len][layout]
    return layout_for_rank.find('N')
def shape_for_layout(layout: str, **kwargs):
    """
    Creates 4D or 5D tensor with the layout with specified dimension sizes.
    :param layout: layout string.
    :param kwargs: dictionary that contains the dimension sizes using the following keys: 'batch', 'features', 'depth',
    'height', 'width'. The 'depth' key is optional: the 5D shape is produced when it is specified (and is not None),
    otherwise the 4D shape is produced. All other keys are required.
    :return: np.array of type np.int64 with 4 or 5 elements.
    :raises Error: if one of the required keys is missing or an unsupported key is passed.
    """
    assert layout in supported_layouts
    for required_key in ('batch', 'features', 'height', 'width'):
        if required_key not in kwargs:
            raise Error('Required parameter "{}" is missing.'.format(required_key))
    for key in kwargs:
        if key not in ('batch', 'features', 'height', 'width', 'depth'):
            raise Error('Parameter "{}" is not supported.'.format(key))

    depth = kwargs.get('depth')
    # the depth dimension exists in 5D tensors only
    shape_len = 5 if depth is not None else 4
    output_shape = np.ones(shape=[shape_len], dtype=np.int64)
    output_shape[get_batch_dim(layout, shape_len)] = kwargs['batch']
    output_shape[get_height_dim(layout, shape_len)] = kwargs['height']
    output_shape[get_width_dim(layout, shape_len)] = kwargs['width']
    output_shape[get_features_dim(layout, shape_len)] = kwargs['features']
    if depth is not None:
        output_shape[get_depth_dim(layout, shape_len)] = depth
    # the assembled shape is returned to the caller, as promised by the ':return:' section
    return output_shape