"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import pkgutil
import sys
+from mo.back.replacement import BackReplacementPattern
+from mo.middle.replacement import MiddleReplacementPattern
+from mo.ops.op import Op
+from mo.utils.class_registration import _check_unique_ids, update_registration, get_enabled_and_disabled_transforms
+
def import_by_path(path: str, middle_names: list = ()):
for module_loader, name, ispkg in pkgutil.iter_modules([path]):
importlib.import_module('{}.{}'.format('.'.join(middle_names), name))
-def load_dir(framework: str, path: str, update_registration: callable):
+def default_path():
+ EXT_DIR_NAME = 'extensions'
+ return os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, EXT_DIR_NAME))
+
+
+def load_dir(framework: str, path: str, get_front_classes: callable):
"""
Assuming the following sub-directory structure for path:
log.info("Importing extensions from: {}".format(path))
root_dir, ext = os.path.split(path)
sys.path.insert(0, root_dir)
- internal_dirs = [['ops', ], ['front', ], ['front', framework], ['middle', ], ['back', ]]
+
+ enabled_transforms, disabled_transforms = get_enabled_and_disabled_transforms()
+
+ front_classes = get_front_classes()
+ internal_dirs = {
+ ('ops', ): [Op],
+ ('front', ): front_classes,
+ ('front', framework): front_classes,
+ ('middle', ): [MiddleReplacementPattern],
+ ('back', ): [BackReplacementPattern]}
+
if ext == 'mo':
- internal_dirs.append(['front', framework, 'extractors'])
- for p in internal_dirs:
+ internal_dirs[('front', framework, 'extractors')] = front_classes
+
+ for p in internal_dirs.keys():
import_by_path(os.path.join(path, *p), [ext, *p])
- update_registration()
+ update_registration(internal_dirs[p], enabled_transforms, disabled_transforms)
sys.path.remove(root_dir)
-def default_path():
- EXT_DIR_NAME = 'extensions'
- return os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, EXT_DIR_NAME))
-
-
-def load_dirs(framework: str, dirs: list, update_registration: callable):
+def load_dirs(framework: str, dirs: list, get_front_classes: callable):
if dirs is None:
return
+
mo_inner_extensions = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, 'mo'))
dirs.insert(0, mo_inner_extensions)
dirs = [os.path.abspath(e) for e in dirs]
if default_path() not in dirs:
dirs.insert(0, default_path())
for path in dirs:
- load_dir(framework, path, update_registration)
+ load_dir(framework, path, get_front_classes)
+
+ _check_unique_ids()