Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / tf / extractors / strided_slice.py
index cc2ecd2..909c10f 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  See the License for the specific language governing permissions and
  limitations under the License.
 """
+import numpy as np
 
-from mo.front.common.partial_infer.slice import tf_strided_slice_infer
+from mo.front.extractor import FrontExtractorOp
+from mo.ops.op import Op
 
 
-def tf_strided_slice_ext(pb):
-    return {
-        'begin_mask': pb.attr["begin_mask"].i,
-        'end_mask': pb.attr["end_mask"].i,
-        'ellipsis_mask': pb.attr["ellipsis_mask"].i,
-        'new_axis_mask': pb.attr["new_axis_mask"].i,
-        'shrink_axis_mask': pb.attr["shrink_axis_mask"].i,
-        'infer': tf_strided_slice_infer
-    }
+def int_to_array_bit_mask(im):
+    list_repr = list(np.binary_repr(im))
+    list_repr.reverse()
+    list_repr = [int(li) for li in list_repr]
+    return np.array(list_repr, dtype=np.int32)
+
+
+class StridedSliceFrontExtractor(FrontExtractorOp):
+    op = 'StridedSlice'
+    enabled = True
+
+    @staticmethod
+    def extract(node):
+        pb = node.pb
+        bm = int_to_array_bit_mask(pb.attr["begin_mask"].i)
+        bm = np.array([1 - b for b in bm], dtype=np.int32)
+        em = int_to_array_bit_mask(pb.attr["end_mask"].i)
+        em = np.array([1 - b for b in em], dtype=np.int32)
+        attrs = {
+            'begin_mask': bm,
+            'end_mask': em,
+            'ellipsis_mask': int_to_array_bit_mask(pb.attr["ellipsis_mask"].i),
+            'new_axis_mask': int_to_array_bit_mask(pb.attr["new_axis_mask"].i),
+            'shrink_axis_mask': int_to_array_bit_mask(pb.attr["shrink_axis_mask"].i),
+        }
+
+        Op.get_op_class_by_name(__class__.op).update_node_stat(node, attrs)
+        return __class__.enabled