Fix skipping incorrect names in scale/mean values (#535)
authorPavel Esir <pavel.esir@gmail.com>
Wed, 27 May 2020 11:53:50 +0000 (14:53 +0300)
committerGitHub <noreply@github.com>
Wed, 27 May 2020 11:53:50 +0000 (14:53 +0300)
* Fix skipping incorrect names in scale/mean values

* removed inappropriate comment in cli_parser.py

model-optimizer/mo/utils/cli_parser.py
model-optimizer/mo/utils/cli_parser_test.py

index 9dafc6d..a6f649f 100644 (file)
@@ -1011,7 +1011,7 @@ def get_mean_scale_dictionary(mean_values, scale_values, argv_input: str):
     Returns
     -------
     The function returns a dictionary e.g.
-    mean = { 'data: np.array, 'info': np.array }, scale = { 'data: np.array, 'info': np.array }, input = "data, info" ->
+    mean = { 'data': np.array, 'info': np.array }, scale = { 'data': np.array, 'info': np.array }, input = "data, info" ->
      { 'data': { 'mean': np.array, 'scale': np.array }, 'info': { 'mean': np.array, 'scale': np.array } }
 
     """
@@ -1032,6 +1032,17 @@ def get_mean_scale_dictionary(mean_values, scale_values, argv_input: str):
     if type(mean_values) is dict and type(scale_values) is dict:
         if not mean_values and not scale_values:
             return res
+
+        for inp_scale in scale_values.keys():
+            if inp_scale not in inputs:
+                raise Error("Specified scale_values name '{}' do not match to any of inputs: {}. "
+                            "Please set 'scale_values' that correspond to values from input.".format(inp_scale, inputs))
+
+        for inp_mean in mean_values.keys():
+            if inp_mean not in inputs:
+                raise Error("Specified mean_values name '{}' do not match to any of inputs: {}. "
+                            "Please set 'mean_values' that correspond to values from input.".format(inp_mean, inputs))
+
         for inp in inputs:
             inp, port = split_node_in_port(inp)
             if inp in mean_values or inp in scale_values:
@@ -1105,7 +1116,7 @@ def get_mean_scale_dictionary(mean_values, scale_values, argv_input: str):
                     }
                 )
             return res
-    # mean and scale are specified without inputs, return list, order is not guaranteed (?)
+    # mean and/or scale are specified without inputs
     return list(zip_longest(mean_values, scale_values))
 
 
index b3a2cb3..deefa5a 100644 (file)
@@ -345,6 +345,15 @@ class TestingMeanScaleGetter(unittest.TestCase):
             for j in range(0, len(exp_res[i])):
                 np.array_equal(exp_res[i][j], result[i][j])
 
+    def test_scale_do_not_match_input(self):
+        scale_values = parse_tuple_pairs("input_not_present(255),input2(255)")
+        mean_values = parse_tuple_pairs("input1(255),input2(255)")
+        self.assertRaises(Error, get_mean_scale_dictionary, mean_values, scale_values, "input1,input2")
+
+    def test_mean_do_not_match_input(self):
+        scale_values = parse_tuple_pairs("input1(255),input2(255)")
+        mean_values = parse_tuple_pairs("input_not_present(255),input2(255)")
+        self.assertRaises(Error, get_mean_scale_dictionary, mean_values, scale_values, "input1,input2")
 
 class TestSingleTupleParsing(unittest.TestCase):
     def test_get_values_ideal(self):