Support python 3.8 by the Model Optimizer tool in default configuration (#2078)
authorRoman Kazantsev <roman.kazantsev@intel.com>
Wed, 9 Sep 2020 05:34:43 +0000 (08:34 +0300)
committerGitHub <noreply@github.com>
Wed, 9 Sep 2020 05:34:43 +0000 (08:34 +0300)
* Support python 3.8 by the Model Optimizer tool in default configuration

* Fix after review #1

* Fix after the second round review

model-optimizer/mo/utils/versions_checker.py
model-optimizer/mo/utils/versions_checker_test.py
model-optimizer/requirements.txt
model-optimizer/requirements_tf.txt

index d98a8dd..0b53281 100644 (file)
@@ -44,17 +44,24 @@ def check_python_version():
         return 1
 
 
-def parse_versions_list(required_fw_versions, version_list):
+def parse_and_filter_versions_list(required_fw_versions, version_list, env_setup):
     """
     Please do not add parameter type annotations (param:type).
     Because we import this file while checking Python version.
     Python 2.x will fail with no clear message on type annotations.
 
-    Parsing requirements versions
+    Parsing requirements versions for a dependency and filtering out requirements that
+    satisfy environment setup such as python version.
+    if environment version (python_version, etc.) is satisfied
     :param required_fw_versions: String with fw versions from requirements file
     :param version_list: List for append
+    :param env_setup: a dictionary with environment setup
     :return: list of tuples of strings like (name_of_module, sign, version)
 
+    Examples of required_fw_versions:
+    'tensorflow>=1.15.2,<2.0; python_version < "3.8"'
+    'tensorflow>=2.0'
+
     Returned object is:
     [('tensorflow', '>=', '1.2.0'), ('networkx', '==', '2.1'), ('numpy', None, None)]
     """
@@ -62,26 +69,57 @@ def parse_versions_list(required_fw_versions, version_list):
     line = required_fw_versions.strip('\n')
     line = line.strip(' ')
     if line == '':
-        return []
-    splited_versions_by_conditions = re.split(r"==|>=|<=|>|<", line)
+        return version_list
+    splited_requirement = line.split(";")
+
+    # check environment marker
+    if len(splited_requirement) > 1:
+        env_req = splited_requirement[1]
+        splited_env_req = re.split(r"==|>=|<=|>|<", env_req)
+        splited_env_req = [l.strip(',') for l in splited_env_req]
+        env_marker = splited_env_req[0].strip(' ')
+        if env_marker == 'python_version' and env_marker in env_setup:
+            installed_python_version = env_setup['python_version']
+            env_req_version_list = []
+            splited_required_versions = re.split(r",", env_req)
+            for i, l in enumerate(splited_required_versions):
+                for comparison in ['==', '>=', '<=', '<', '>']:
+                    if comparison in l:
+                        required_version = splited_env_req[i + 1].strip(' ').replace('"', '')
+                        env_req_version_list.append((env_marker, comparison, required_version))
+                        break
+            not_satisfied_list = []
+            for name, key, required_version in env_req_version_list:
+                version_check(name, installed_python_version, required_version,
+                              key, not_satisfied_list, 0)
+            if len(not_satisfied_list) > 0:
+                # this python_version requirement is not satisfied to required environment
+                # and requirement for a dependency will be skipped
+                return version_list
+        else:
+            log.error("{} is unsupported environment marker and it will be ignored".format(env_marker),
+                      extra={'is_warning': True})
+
+    # parse a requirement for a dependency
+    requirement = splited_requirement[0]
+    splited_versions_by_conditions = re.split(r"==|>=|<=|>|<", requirement)
     splited_versions_by_conditions = [l.strip(',') for l in splited_versions_by_conditions]
 
     if len(splited_versions_by_conditions) == 0:
-        return []
+        return version_list
     if len(splited_versions_by_conditions) == 1:
         version_list.append((splited_versions_by_conditions[0], None, None))
     else:
-        splited_required_versions= re.split(r",", line)
+        splited_required_versions= re.split(r",", requirement)
         for i, l in enumerate(splited_required_versions):
-            comparisons = ['==', '>=', '<=', '<', '>']
-            for comparison in comparisons:
+            for comparison in ['==', '>=', '<=', '<', '>']:
                 if comparison in l:
                     version_list.append((splited_versions_by_conditions[0], comparison, splited_versions_by_conditions[i + 1]))
                     break
     return version_list
 
 
-def get_module_version_list_from_file(file_name):
+def get_module_version_list_from_file(file_name, env_setup):
     """
     Please do not add parameter type annotations (param:type).
     Because we import this file while checking Python version.
@@ -89,6 +127,7 @@ def get_module_version_list_from_file(file_name):
 
     Reads file with requirements
     :param file_name: Name of the requirements file
+    :param env_setup: a dictionary with environment setup elements
     :return: list of tuples of strings like (name_of_module, sign, version)
 
     File content example:
@@ -102,7 +141,7 @@ def get_module_version_list_from_file(file_name):
     req_dict = list()
     with open(file_name) as f:
         for line in f:
-            req_dict = parse_versions_list(line, req_dict)
+            req_dict = parse_and_filter_versions_list(line, req_dict, env_setup)
     return req_dict
 
 
@@ -113,7 +152,7 @@ def version_check(name, installed_v, required_v, sign, not_satisfied_v, exit_cod
     Python 2.x will fail with no clear message on type annotations.
 
     Evaluates comparison of installed and required versions according to requirements file of one module.
-    If installed version does not satisfy requirements appends this module to not_stisfied_v list.
+    If installed version does not satisfy requirements appends this module to not_satisfied_v list.
     :param name: module name
     :param installed_v: installed version of module
     :param required_v: required version of module
@@ -146,6 +185,25 @@ def version_check(name, installed_v, required_v, sign, not_satisfied_v, exit_cod
     return exit_code
 
 
+def get_environment_setup():
+    """
+    Get environment setup such as Python version, TensorFlow version
+    :return: a dictionary of environment variables
+    """
+    env_setup = dict()
+    python_version = "{}.{}.{}".format(sys.version_info.major,
+                                       sys.version_info.minor,
+                                       sys.version_info.micro)
+    env_setup['python_version'] = python_version
+    try:
+        exec("import tensorflow")
+        env_setup['tensorflow'] = sys.modules["tensorflow"].__version__
+        exec("del tensorflow")
+    except (AttributeError, ImportError):
+        pass
+    return env_setup
+
+
 def check_requirements(framework=None):
     """
     Please do not add parameter type annotations (param:type).
@@ -158,13 +216,20 @@ def check_requirements(framework=None):
     :param framework: framework name
     :return: exit code (0 - execution successful, 1 - error)
     """
+    env_setup = get_environment_setup()
     if framework is None:
         framework_suffix = ""
+    elif framework == "tf":
+        if "tensorflow" in env_setup and env_setup["tensorflow"] >= LooseVersion("2.0.0"):
+            framework_suffix = "_tf2"
+        else:
+            framework_suffix = "_tf"
     else:
         framework_suffix = "_{}".format(framework)
+
     file_name = "requirements{}.txt".format(framework_suffix)
     requirements_file = os.path.realpath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, file_name))
-    requirements_list = get_module_version_list_from_file(requirements_file)
+    requirements_list = get_module_version_list_from_file(requirements_file, env_setup)
     not_satisfied_versions = []
     exit_code = 0
     for name, key, required_version in requirements_list:
index 227b74e..35346d8 100644 (file)
@@ -18,7 +18,7 @@ import unittest
 import unittest.mock as mock
 from unittest.mock import mock_open
 
-from mo.utils.versions_checker import get_module_version_list_from_file, parse_versions_list
+from mo.utils.versions_checker import get_module_version_list_from_file, parse_and_filter_versions_list
 
 
 class TestingVersionsChecker(unittest.TestCase):
@@ -30,18 +30,51 @@ class TestingVersionsChecker(unittest.TestCase):
         ref_list =[('mxnet', '>=', '1.0.0'), ('mxnet', '<=', '1.3.1'),
                           ('networkx', '>=', '1.11'),
                           ('numpy', '==', '1.12.0'), ('defusedxml', '<=', '0.5.0')]
-        version_list = get_module_version_list_from_file('mock_file')
+        version_list = get_module_version_list_from_file('mock_file', {})
         self.assertEqual(len(version_list), 5)
         for i, version_dict in enumerate(version_list):
             self.assertTupleEqual(ref_list[i], version_dict)
 
     @mock.patch('builtins.open', new_callable=mock_open, create=True)
+    def test_get_module_version_list_from_file2(self, mock_open):
+        mock_open.return_value.__enter__ = mock_open
+        mock_open.return_value.__iter__ = mock.Mock(
+            return_value=iter(['tensorflow>=1.15.2,<2.0; python_version < "3.8"',
+                               'tensorflow>=2.0; python_version >= "3.8"',
+                               'numpy==1.12.0',
+                               'defusedxml<=0.5.0']))
+        ref_list =[('tensorflow', '>=', '1.15.2'),
+                   ('tensorflow', '<', '2.0'),
+                   ('numpy', '==', '1.12.0'),
+                   ('defusedxml', '<=', '0.5.0')]
+        version_list = get_module_version_list_from_file('mock_file', {'python_version': '3.7.0'})
+        self.assertEqual(len(version_list), 4)
+        for i, version_dict in enumerate(version_list):
+            self.assertTupleEqual(ref_list[i], version_dict)
+
+    @mock.patch('builtins.open', new_callable=mock_open, create=True)
+    def test_get_module_version_list_from_file3(self, mock_open):
+        mock_open.return_value.__enter__ = mock_open
+        mock_open.return_value.__iter__ = mock.Mock(
+            return_value=iter(['tensorflow>=1.15.2,<2.0; python_version < "3.8"',
+                               'tensorflow>=2.0; python_version >= "3.8"',
+                               'numpy==1.12.0',
+                               'defusedxml<=0.5.0']))
+        ref_list =[('tensorflow', '>=', '2.0'),
+                   ('numpy', '==', '1.12.0'),
+                   ('defusedxml', '<=', '0.5.0')]
+        version_list = get_module_version_list_from_file('mock_file', {'python_version': '3.8.1'})
+        self.assertEqual(len(version_list), 3)
+        for i, version_dict in enumerate(version_list):
+            self.assertTupleEqual(ref_list[i], version_dict)
+
+    @mock.patch('builtins.open', new_callable=mock_open, create=True)
     def test_get_module_version_list_from_file_with_fw_name(self, mock_open):
         mock_open.return_value.__enter__ = mock_open
         mock_open.return_value.__iter__ = mock.Mock(
             return_value=iter(['mxnet']))
         ref_list = [('mxnet', None, None)]
-        version_list = get_module_version_list_from_file('mock_file')
+        version_list = get_module_version_list_from_file('mock_file', {})
         self.assertEqual(len(version_list), 1)
         for i, version_dict in enumerate(version_list):
             self.assertTupleEqual(ref_list[i], version_dict)
@@ -49,7 +82,7 @@ class TestingVersionsChecker(unittest.TestCase):
     def test_append_version_list(self):
         v1 = 'mxnet>=1.0.0,<=1.3.1'
         req_list = list()
-        parse_versions_list(v1, req_list)
+        parse_and_filter_versions_list(v1, req_list, {})
         ref_list = [('mxnet', '>=', '1.0.0'),
                     ('mxnet', '<=', '1.3.1')]
         for i, v in enumerate(req_list):
index e8069df..137b411 100644 (file)
@@ -1,4 +1,5 @@
-tensorflow>=1.15.2,<2.0
+tensorflow>=1.15.2,<2.0; python_version < "3.8"
+tensorflow>=2.0; python_version >= "3.8"
 mxnet>=1.0.0,<=1.5.1
 networkx>=1.11
 numpy>=1.13.0
index ef7e24e..a22cd69 100644 (file)
@@ -1,4 +1,5 @@
-tensorflow>=1.15.2,<2.0
+tensorflow>=1.15.2,<2.0; python_version < "3.8"
+tensorflow>=2.0; python_version >= "3.8"
 networkx>=1.11
 numpy>=1.13.0
 test-generator==0.1.1