97197da063e910d9d5b06b98adc45d5601114159
[platform/upstream/dldt.git] / ngraph / python / src / ngraph / utils / reduction.py
1 # ******************************************************************************
2 # Copyright 2017-2020 Intel Corporation
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15 # ******************************************************************************
16
17 from typing import Iterable, Optional
18
19 from ngraph.impl import Node
20
21
22 def get_reduction_axes(node: Node, reduction_axes: Optional[Iterable[int]]) -> Iterable[int]:
23     """! Get reduction axes if it is None and convert it to set if its type is different.
24
25     If reduction_axes is None we default to reduce all axes.
26
27     @param node: The node we fill reduction axes for.
28     @param reduction_axes: The collection of indices of axes to reduce. May be None.
29     @return Set filled with indices of axes we want to reduce.
30     """
31     if reduction_axes is None:
32         reduction_axes = set(range(len(node.shape)))
33
34     if type(reduction_axes) is not set:
35         reduction_axes = set(reduction_axes)
36     return reduction_axes