2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
def space_to_batch_infer(node):
    """
    Infer the output shape of a SpaceToBatch(ND) node.

    https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch
    https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d

    The node must have exactly three inputs:
      * port 0 - data tensor (only its shape is used),
      * port 1 - block shape (constant value required),
      * port 2 - paddings, one row per spatial dimension (constant value required).
    If any of these is unavailable, inference is skipped and the output shape is
    left untouched.

    With M = len(block_shape) the data shape is read as
    [batch] + spatial_shape (M dims) + remaining_shape, and the output shape is
    [batch * prod(block_shape)] + padded_spatial_shape // block_shape + remaining_shape.

    :param node: graph node whose output node's ``shape`` attribute is set.
    """
    input_shape = node.in_node(0).shape
    if input_shape is None:
        return

    if len(node.in_nodes()) != 3:
        return

    if node.in_node(1).value is None or node.in_node(2).value is None:
        return

    block_size = node.in_node(1).value
    pad = node.in_node(2).value
    spatial_rank = len(block_size)

    # pad[:, 0] / pad[:, 1] are the leading / trailing paddings of each spatial dim
    padded = pad[:, 0] + input_shape[1:spatial_rank + 1] + pad[:, 1]

    output_shape = [
        input_shape[0] * np.prod(block_size),
        # integer division: the TF op requires each padded dim to be divisible by its block size
        *[int(x) for x in padded // block_size],
        # keep every trailing (non-spatial) dimension, not just the last one
        *input_shape[spatial_rank + 1:],
    ]
    node.out_node().shape = np.array(output_shape)
def batch_to_space_infer(node):
    """
    Infer the output shape of a BatchToSpace(ND) node.

    https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space
    https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d

    The node must have exactly three inputs:
      * port 0 - data tensor (only its shape is used),
      * port 1 - block shape (constant value required),
      * port 2 - crops, one row per spatial dimension (constant value required).
    If any of these is unavailable, inference is skipped and the output shape is
    left untouched.

    With M = len(block_shape) the data shape is read as
    [batch] + spatial_shape (M dims) + remaining_shape, and the output shape is
    [batch // prod(block_shape)]
      + (spatial_shape * block_shape - crops[:, 0] - crops[:, 1])
      + remaining_shape.

    :param node: graph node whose output node's ``shape`` attribute is set.
    """
    input_shape = node.in_node(0).shape
    if input_shape is None:
        return

    if len(node.in_nodes()) != 3:
        return

    if node.in_node(1).value is None or node.in_node(2).value is None:
        return

    block_size = node.in_node(1).value
    crop = node.in_node(2).value
    spatial_rank = len(block_size)

    # spatial dims after moving the block factor out of the batch, before cropping
    uncropped = block_size * input_shape[1:spatial_rank + 1]
    sizes = uncropped - crop[:, 0] - crop[:, 1]
    # integer division: the TF op requires the batch to be divisible by prod(block_shape)
    batch = int(input_shape[0] // np.prod(block_size))

    # keep every trailing (non-spatial) dimension, not just the last one
    output_shape = [batch, *sizes, *input_shape[spatial_rank + 1:]]
    node.out_node().shape = np.array(output_shape)