Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / utils / import_extensions.py
index 317bef6..0ed0ce6 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -20,13 +20,23 @@ import os
 import pkgutil
 import sys
 
+from mo.back.replacement import BackReplacementPattern
+from mo.middle.replacement import MiddleReplacementPattern
+from mo.ops.op import Op
+from mo.utils.class_registration import _check_unique_ids, update_registration, get_enabled_and_disabled_transforms
+
 
 def import_by_path(path: str, middle_names: list = ()):
     for module_loader, name, ispkg in pkgutil.iter_modules([path]):
         importlib.import_module('{}.{}'.format('.'.join(middle_names), name))
 
 
-def load_dir(framework: str, path: str, update_registration: callable):
+def default_path():
+    EXT_DIR_NAME = 'extensions'
+    return os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, EXT_DIR_NAME))
+
+
+def load_dir(framework: str, path: str, get_front_classes: callable):
     """
     Assuming the following sub-directory structure for path:
 
@@ -57,27 +67,36 @@ def load_dir(framework: str, path: str, update_registration: callable):
     log.info("Importing extensions from: {}".format(path))
     root_dir, ext = os.path.split(path)
     sys.path.insert(0, root_dir)
-    internal_dirs = [['ops', ], ['front', ], ['front', framework], ['middle', ], ['back', ]]
+
+    enabled_transforms, disabled_transforms = get_enabled_and_disabled_transforms()
+
+    front_classes = get_front_classes()
+    internal_dirs = {
+                         ('ops', ): [Op],
+                         ('front', ): front_classes,
+                         ('front', framework): front_classes,
+                         ('middle', ): [MiddleReplacementPattern],
+                         ('back', ): [BackReplacementPattern]}
+
     if ext == 'mo':
-        internal_dirs.append(['front', framework, 'extractors'])
-    for p in internal_dirs:
+        internal_dirs[('front', framework, 'extractors')] = front_classes
+
+    for p in internal_dirs.keys():
         import_by_path(os.path.join(path, *p), [ext, *p])
-        update_registration()
+        update_registration(internal_dirs[p], enabled_transforms, disabled_transforms)
     sys.path.remove(root_dir)
 
 
-def default_path():
-    EXT_DIR_NAME = 'extensions'
-    return os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, EXT_DIR_NAME))
-
-
-def load_dirs(framework: str, dirs: list, update_registration: callable):
+def load_dirs(framework: str, dirs: list, get_front_classes: callable):
     if dirs is None:
         return
+
     mo_inner_extensions = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, 'mo'))
     dirs.insert(0, mo_inner_extensions)
     dirs = [os.path.abspath(e) for e in dirs]
     if default_path() not in dirs:
         dirs.insert(0, default_path())
     for path in dirs:
-        load_dir(framework, path, update_registration)
+        load_dir(framework, path, get_front_classes)
+
+    _check_unique_ids()