2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from mo.front.tf.common import tf_data_type_decode
22 from mo.utils.error import Error
23 from mo.utils.utils import refer_to_faq_msg
def tf_tensor_shape(pb):
    """Convert a TF TensorShapeProto into a 1-D numpy array of int64 dimension sizes."""
    sizes = [d.size for d in pb.dim]
    return np.array(sizes, dtype=np.int64)
31 return np.array(pb.i, dtype=np.int64)
def tf_dtype_extractor(pb_dtype, default=None):
    """Map a TF proto dtype enum to its decoded (numpy) type via tf_data_type_decode.

    Falls back to `default` when the enum value has no known decoding.
    """
    entry = tf_data_type_decode.get(pb_dtype)
    if entry is None:
        return default
    return entry[0]
def tf_data_format_spatial(pb):
    """Return positions of the spatial dimensions in a TF data_format attribute.

    :param pb: attribute protobuf whose bytes field `s` holds a layout string
               such as b'NHWC', b'NCHW', b'NDHWC' or b'NCDHW'
    :return: list of indices of the D (3-D layouts only), H and W dimensions
    """
    # A 3-D layout carries a depth ('D') dimension; without this guard the
    # DHW lookup raises ValueError for plain 2-D layouts like b'NHWC'.
    if b"D" in pb.s:
        return [pb.s.index(c) for c in b"DHW"]
    return [pb.s.index(c) for c in b"HW"]
def tf_data_format_channel(pb):
    """Return the position of the channel ('C') dimension in the data_format string."""
    position = pb.s.index(b'C')
    return [position]
def tf_data_format_batch(pb):
    """Return the position of the batch ('N') dimension in the data_format string."""
    position = pb.s.index(b'N')
    return [position]
def get_tf_node_port(tensor):
    """Split a TF tensor reference into (node name, output port).

    :param tensor: string of the form 'name:port' or just 'name'
    :return: tuple (name, port) where port defaults to 0
    """
    delim = ':'  # was missing: every call raised NameError without it
    # tensor should have form 'name:port' or just 'name'
    name_parts = tensor.split(delim)
    if len(name_parts) == 1:
        # just 'name', then port is 0 by default
        return name_parts[0], 0
    # 'name:port', note name can contain ':' also but port is the last part
    # TODO Is 'name' that contains other ':'s considered valid by TF?
    return delim.join(name_parts[:-1]), int(name_parts[-1])
def tf_tensor_content(tf_dtype, shape, pb_tensor):
    """Extract the value of a TF TensorProto as a numpy array.

    :param tf_dtype: TF data type enum used to look up (numpy type, decode fn)
                     in tf_data_type_decode
    :param shape: 1-D numpy array with the expected tensor shape
    :param pb_tensor: TensorProto message holding either raw `tensor_content`
                      bytes or a typed value field
    :return: numpy array with the tensor value
    :raises Error: when tf_dtype has no entry in tf_data_type_decode
    """
    type_helper = tf_data_type_decode[tf_dtype] if tf_dtype in tf_data_type_decode else None
    if type_helper is None:
        raise Error("Data type is unsupported: {}. " +
                    refer_to_faq_msg(50), tf_dtype)

    # NOTE(review): the branch guards below reconstruct control flow that was
    # garbled in the reviewed copy (unreachable returns) — confirm against the
    # original file.
    if len(shape) == 0:
        # Scalar tensor: the value lives in a typed value field. Decode once
        # (the old code re-ran the decoder three times just for logging).
        value = np.array(type_helper[1](pb_tensor)).copy()
        assert len(value) == 1
        result = np.array(value[0], dtype=type_helper[0])
        log.debug("value = {}, shape = {}, res = {}, res.shape = {}".format(
            str(value), shape, result, result.shape))
        return result

    if pb_tensor.tensor_content:
        # Raw bytes payload: reinterpret with the numpy dtype for this TF type.
        flat = np.array(np.frombuffer(pb_tensor.tensor_content, type_helper[0]))
        if len(flat) == shape.prod():
            return flat.reshape(shape)
        log.warning("Shape and content size of tensor don't match, shape: {} content size: {}".
                    format(shape, len(flat)))
        # broadcast semantics: no reshape
        return flat

    # probably a broadcast semantics
    # load constant instead of tensor
    value = np.array(type_helper[1](pb_tensor), dtype=type_helper[0])
    log.warning("Broadcast of scalar to shape: {}".format(shape))
    return np.broadcast_to(value, shape=shape).copy()
def check_attr_type(a):
    """
    Check type of attribute from TF prototxt message
    param: a - attribute from TF prototxt message
    return: type of attribute
    """
    # NOTE(review): the per-field probes were truncated in the reviewed copy;
    # reconstructed as the standard AttrValue oneof dispatch — the first
    # non-empty field determines the attribute kind. Returns None when no
    # field is set (fall-through), matching the original's implicit behavior.
    if a.s:
        return 's'
    if a.i:
        return 'i'
    if a.f:
        return 'f'
    if a.b:
        return 'b'
    if a.type:
        return 'type'
    if a.shape and a.shape.dim:
        return 'shape'
    if a.list:
        return 'list'
121 def collect_tf_attrs(attrs):
123 Function generates map for attributes and parsing functions
124 param: attrs - TF proto message with attributes
125 return: mapping attributes and parsing functions ready for use in update_node_stat function
133 'type': lambda x: tf_dtype_extractor(x.type),
134 'shape': lambda x: tf_tensor_shape(x.shape),
135 'list': lambda x: x.list
139 t = check_attr_type(attrs[a])
142 a_l = type_parsers[t](attrs[a])
143 t = check_attr_type(a_l)
145 ret_attrs[a] = type_parsers[t](a_l)