[ MO ONNX ] Resize-11 clear error message (#620)
authorEvgenya Stepyreva <evgenya.stepyreva@intel.com>
Wed, 27 May 2020 05:09:15 +0000 (08:09 +0300)
committerGitHub <noreply@github.com>
Wed, 27 May 2020 05:09:15 +0000 (08:09 +0300)
* Small refactoring of extractors

* [ MO ] Throwing an exception while extracting Resize-11 which is not supported

model-optimizer/extensions/front/onnx/resize_ext.py
model-optimizer/extensions/front/onnx/reverse_sequence_ext.py
model-optimizer/extensions/front/tf/bucketize_ext.py
model-optimizer/extensions/front/tf/sparse_to_dense_ext.py
model-optimizer/mo/front/kaldi/extractors/linear_component_ext.py

index f9e232e..c8aefdb 100644 (file)
 
 from extensions.ops.upsample import UpsampleOp
 from mo.front.extractor import FrontExtractorOp
-from mo.front.onnx.extractors.utils import onnx_attr
+from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_opset_version
 from mo.graph.graph import Node
+from mo.utils.error import Error
 
 
 class ResizeExtractor(FrontExtractorOp):
     op = 'Resize'
     enabled = True
 
-    @staticmethod
-    def extract(node: Node):
+    @classmethod
+    def extract(cls, node: Node):
+        onnx_opset_version = get_onnx_opset_version(node)
+        if onnx_opset_version is not None and onnx_opset_version >= 11:
+            raise Error("ONNX Resize operation from opset {} is not supported.".format(onnx_opset_version))
         mode = onnx_attr(node, 'mode', 's', default=b'nearest').decode()
         UpsampleOp.update_node_stat(node, {'mode': mode})
-        return __class__.enabled
+        return cls.enabled
index 48cf743..c8e4313 100644 (file)
@@ -23,8 +23,8 @@ class ReverseSequenceExtractor(FrontExtractorOp):
     op = 'ReverseSequence'
     enabled = True
 
-    @staticmethod
-    def extract(node):
+    @classmethod
+    def extract(cls, node):
         batch_axis = onnx_attr(node, 'batch_axis', 'i', default=1)
         time_axis = onnx_attr(node, 'time_axis', 'i', default=0)
 
@@ -33,4 +33,4 @@ class ReverseSequenceExtractor(FrontExtractorOp):
             'seq_axis': time_axis,
         }
         ReverseSequence.update_node_stat(node, attrs)
-        return __class__.enabled
+        return cls.enabled
index 1a17b39..595c9d6 100644 (file)
@@ -24,9 +24,8 @@ class BucketizeFrontExtractor(FrontExtractorOp):
     op = 'Bucketize'
     enabled = True
 
-    @staticmethod
-    def extract(node):
+    @classmethod
+    def extract(cls, node):
         boundaries = np.array(node.pb.attr['boundaries'].list.f, dtype=np.float)
         Bucketize.update_node_stat(node, {'boundaries': boundaries, 'with_right_bound': False})
-
-        return __class__.enabled
+        return cls.enabled
index b331775..c080294 100644 (file)
@@ -22,8 +22,7 @@ class SparseToDenseFrontExtractor(FrontExtractorOp):
     op = 'SparseToDense'
     enabled = True
 
-    @staticmethod
-    def extract(node):
+    @classmethod
+    def extract(cls, node):
         SparseToDense.update_node_stat(node)
-
-        return __class__.enabled
+        return cls.enabled
index e1a7a5e..6453655 100644 (file)
@@ -25,8 +25,8 @@ class LinearComponentFrontExtractor(FrontExtractorOp):
     op = 'linearcomponent'
     enabled = True
 
-    @staticmethod
-    def extract(node):
+    @classmethod
+    def extract(cls, node):
         pb = node.parameters
         collect_until_token(pb, b'<Params>')
         weights, weights_shape = read_binary_matrix(pb)
@@ -39,4 +39,4 @@ class LinearComponentFrontExtractor(FrontExtractorOp):
         embed_input(mapping_rule, 1, 'weights', weights)
 
         FullyConnected.update_node_stat(node, mapping_rule)
-        return __class__.enabled
+        return cls.enabled