Adjust tolerance values in ModelImportRunner to the ones onnxruntime … (#1894)
authorMateusz Tabaka <mateusz.tabaka@intel.com>
Tue, 25 Aug 2020 14:12:55 +0000 (16:12 +0200)
committerGitHub <noreply@github.com>
Tue, 25 Aug 2020 14:12:55 +0000 (16:12 +0200)
ngraph/python/tests/test_onnx/test_additional_models.py

index 43726ff..13d04f9 100644 (file)
@@ -32,12 +32,36 @@ MODELS_ROOT_DIR = tests.ADDITIONAL_MODELS_DIR
 if len(MODELS_ROOT_DIR) == 0:
     MODELS_ROOT_DIR = _get_default_additional_models_dir()
 
+tolerance_map = {
+    "arcface_lresnet100e_opset8": {"atol": 0.001, "rtol": 0.001},
+    "mobilenet_opset7": {"atol": 0.001, "rtol": 0.001},
+    "resnet50_v2_opset7": {"atol": 0.001, "rtol": 0.001},
+    "test_mobilenetv2-1.0": {"atol": 0.001, "rtol": 0.001},
+    "test_resnet101v2": {"atol": 0.001, "rtol": 0.001},
+    "test_resnet18v2": {"atol": 0.001, "rtol": 0.001},
+    "test_resnet34v2": {"atol": 0.001, "rtol": 0.001},
+    "test_resnet50v2": {"atol": 0.001, "rtol": 0.001},
+    "mosaic": {"atol": 0.001, "rtol": 0.001},
+    "pointilism": {"atol": 0.001, "rtol": 0.001},
+    "rain_princess": {"atol": 0.001, "rtol": 0.001},
+    "udnie": {"atol": 0.001, "rtol": 0.001},
+}
+
 zoo_models = []
 # rglob doesn't work for symlinks, so models have to be physically somwhere inside "MODELS_ROOT_DIR"
 for path in Path(MODELS_ROOT_DIR).rglob("*.onnx"):
     mdir, file = os.path.split(str(path))
     if not file.startswith("."):
-        zoo_models.append({"model_name": path, "model_file": file, "dir": str(mdir)})
+        mdir = str(mdir)
+        if mdir.endswith("/"):
+            mdir = mdir[:-1]
+        model = {"model_name": path, "model_file": file, "dir": mdir}
+        basedir = os.path.basename(mdir)
+        if basedir in tolerance_map:
+            # updated model looks now:
+            # {"model_name": path, "model_file": file, "dir": mdir, "atol": ..., "rtol": ...}
+            model.update(tolerance_map[basedir])
+        zoo_models.append(model)
 
 if len(zoo_models) > 0:
     sorted(zoo_models, key=itemgetter("model_name"))