1 # ******************************************************************************
2 # Copyright 2017-2020 Intel Corporation
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15 # ******************************************************************************
17 from typing import Iterable, Optional
19 from ngraph.impl import Node
22 def get_reduction_axes(node: Node, reduction_axes: Optional[Iterable[int]]) -> Iterable[int]:
23 """! Get reduction axes if it is None and convert it to set if its type is different.
25 If reduction_axes is None we default to reduce all axes.
27 @param node: The node we fill reduction axes for.
28 @param reduction_axes: The collection of indices of axes to reduce. May be None.
29 @return Set filled with indices of axes we want to reduce.
31 if reduction_axes is None:
32 reduction_axes = set(range(len(node.shape)))
34 if type(reduction_axes) is not set:
35 reduction_axes = set(reduction_axes)