Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / tf / extractors / utils.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 import numpy as np
20
21 from mo.front.tf.common import tf_data_type_decode
22 from mo.utils.error import Error
23 from mo.utils.utils import refer_to_faq_msg
24
25
26 def tf_tensor_shape(pb):
27     return np.array([dim.size for dim in pb.dim], dtype=np.int64)
28
29
30 def tf_int_list(pb):
31     return np.array(pb.i, dtype=np.int64)
32
33
34 def tf_dtype_extractor(pb_dtype, default=None):
35     return tf_data_type_decode[pb_dtype][0] if pb_dtype in tf_data_type_decode else default
36
37
38 def tf_data_format_spatial(pb):
39     if b"DHW" in pb.s:
40         return [pb.s.index(c) for c in b"DHW"]
41     return [pb.s.index(c) for c in b"HW"]
42
43
44 def tf_data_format_channel(pb):
45     return [pb.s.index(b'C')]
46
47
48 def tf_data_format_batch(pb):
49     return [pb.s.index(b'N')]
50
51
52 def get_tf_node_port(tensor):
53     delim = ':'
54     # tensor should have form 'name:port' or just 'name'
55     name_parts = tensor.split(delim)
56     if len(name_parts) == 1:
57         # just 'name', then port is 0 by default
58         return name_parts[0], 0
59     else:
60         # 'name:port', note name can contain ':' also but port is the last part
61         # TODO Is 'name' that contains other ':'s considered valid by TF?
62         return delim.join(name_parts[:-1]), int(name_parts[-1])
63
64
65 def tf_tensor_content(tf_dtype, shape, pb_tensor):
66     type_helper = tf_data_type_decode[tf_dtype] if tf_dtype in tf_data_type_decode else None
67     if type_helper is None:
68         raise Error("Data type is unsupported: {}. " +
69                     refer_to_faq_msg(50), tf_dtype)
70     if len(shape) == 0:
71         value = type_helper[1](pb_tensor)
72         value = np.array(value).copy()
73         assert len(value) == 1
74         log.debug("value = {}, shape = {}, res = {}, res.shape = {}".format(str(type_helper[1](pb_tensor)), shape,
75                                                                             np.array(type_helper[1](pb_tensor),
76                                                                                      dtype=type_helper[0]),
77                                                                             np.array(type_helper[1](pb_tensor),
78                                                                                      dtype=type_helper[0]).shape))
79         return np.array(value[0], dtype=type_helper[0])
80         # return np.array(type_helper[1](pb_tensor), dtype=type_helper[0])
81     else:
82         if pb_tensor.tensor_content:
83             flat = np.array(np.frombuffer(pb_tensor.tensor_content, type_helper[0]))
84             if len(flat) == shape.prod():
85                 return flat.reshape(shape)
86             else:
87                 log.warning("Shape and content size of tensor don't match, shape: {} content size: {}".
88                             format(shape, len(flat)))
89                 # broadcast semantics: no reshape
90                 return flat
91         else:
92             # probably a broadcast semantics
93             # load constant instead of tensor
94             value = np.array(type_helper[1](pb_tensor), dtype=type_helper[0])
95             log.warning("Broadcast of scalar to shape: {}".format(shape))
96             return np.broadcast_to(value, shape=shape).copy()
97
98
99 def check_attr_type(a):
100     """
101       Check type of attribute from TF prototxt message
102       param: a - attribute from TF prototxt message
103       return: type of attribute
104     """
105     if a.s:
106         return 's'
107     if a.i:
108         return 'i'
109     if a.f:
110         return 'f'
111     if a.b:
112         return 'b'
113     if a.type:
114         return 'type'
115     if a.shape and a.shape.dim:
116         return 'shape'
117     if a.list:
118         return 'list'
119
120
121 def collect_tf_attrs(attrs):
122     """
123      Function generates map for attributes and parsing functions
124      param: attrs  - TF proto message with attributes
125      return: mapping attributes and parsing functions ready for use in update_node_stat function
126     """
127     ret_attrs = {}
128     type_parsers = {
129         's': lambda x: x.s,
130         'i': lambda x: x.i,
131         'f': lambda x: x.f,
132         'b': lambda x: x.b,
133         'type': lambda x: tf_dtype_extractor(x.type),
134         'shape': lambda x: tf_tensor_shape(x.shape),
135         'list': lambda x: x.list
136     }
137
138     for a in attrs:
139         t = check_attr_type(attrs[a])
140         a_l = attrs[a]
141         while t == 'list':
142             a_l = type_parsers[t](attrs[a])
143             t = check_attr_type(a_l)
144
145         ret_attrs[a] = type_parsers[t](a_l)
146
147     return ret_attrs