def get_reduction_axes(node: Node, reduction_axes: Optional[Iterable[int]]) -> Iterable[int]:
- """Get reduction axes if it is None and convert it to set if its type is different.
+ """! Get reduction axes if it is None and convert it to set if its type is different.
If reduction_axes is None we default to reduce all axes.
- :param node: The node we fill reduction axes for.
- :param reduction_axes: The collection of indices of axes to reduce. May be None.
- :return: Set filled with indices of axes we want to reduce.
+ @param node: The node we fill reduction axes for.
+ @param reduction_axes: The collection of indices of axes to reduce. May be None.
+ @return Set filled with indices of axes we want to reduce.
"""
if reduction_axes is None:
reduction_axes = set(range(len(node.shape)))