"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import logging as log
-import networkx as nx
import numpy as np
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
from mo.ops.op import Op
op = 'Slice'
enabled = True
- def __init__(self, graph: nx.MultiDiGraph, attrs: dict):
+ def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'type': __class__.op,
'op': 'Slice',
+ 'in_ports_count': 3,
+ 'out_ports_count': 1,
'infer': __class__.infer
}, attrs)
+ def supported_attrs(self):
+ return ['start', 'end', 'axis']
+
@staticmethod
def infer(node: Node):
if len(node.in_nodes()) == 1:
from mo.front.common.partial_infer.slice import caffe_slice_infer
caffe_slice_infer(node)
elif len(node.in_nodes()) == 3:
- #TF case
+ # TF case
start_node = node.in_node(1)
size_node = node.in_node(2)
if start_node.has_valid('value') and size_node.has_valid('value'):
if s is None:
slice_idx[axis] = slice(0, input_shape[axis], 1)
- #Add new parameters to node
+ # Add new parameters to node
node['slices'] = np.array(slice_idx)
node['shrink_axis_mask'] = np.array(shrink_axis_mask)
- value = value[slice_idx]
+ value = value[tuple(slice_idx)]
node.out_node().value = np.array(value) if node.in_node(0).value is not None else None
node.out_node().shape = np.array(value.shape)