Added more version comparison operations (#63848)
authorNived P A <62166124+nivedwho@users.noreply.github.com>
Thu, 9 Sep 2021 17:29:10 +0000 (10:29 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 17:30:20 +0000 (10:30 -0700)
Summary:
Currently the [TorchVersion](https://github.com/pytorch/pytorch/blob/1022443168b5fad55bbd03d087abf574c9d2e9df/torch/torch_version.py#L13) only only supports 'greater than', and 'equal to' operations for comparing torch versions and something like `TorchVersion('1.5.0') < (1,5,1)` or `TorchVersion('1.5.0') >= (1,5)` will throw an error.

I have added 'less than' (`__lt__()`), 'greater than or equal to' (`__ge__()`) and 'less than or equal to' (`__le__()`) operations, so that the TorchVersion object can be useful for wider range of version comparisons.

cc seemethere zsol

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63848

Reviewed By: fmassa, heitorschueroff

Differential Revision: D30526996

Pulled By: seemethere

fbshipit-source-id: 1db6bee555043e0719fd541cec27810852590940

torch/torch_version.py

index 4998c55..f840550 100644 (file)
@@ -1,4 +1,3 @@
-from functools import total_ordering
 from typing import Iterable, Union
 
 from pkg_resources import packaging  # type: ignore[attr-defined]
@@ -9,24 +8,19 @@ InvalidVersion = packaging.version.InvalidVersion
 from .version import __version__ as internal_version
 
 
-@total_ordering
 class TorchVersion(str):
     """A string with magic powers to compare to both Version and iterables!
-
     Prior to 1.10.0 torch.__version__ was stored as a str and so many did
     comparisons against torch.__version__ as if it were a str. In order to not
     break them we have TorchVersion which masquerades as a str while also
     having the ability to compare against both packaging.version.Version as
     well as tuples of values, eg. (1, 2, 1)
-
     Examples:
         Comparing a TorchVersion object to a Version object
             TorchVersion('1.10.0a') > Version('1.10.0a')
-
         Comparing a TorchVersion object to a Tuple object
             TorchVersion('1.10.0a') > (1, 2)    # 1.2
             TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1
-
         Comparing a TorchVersion object against a string
             TorchVersion('1.10.0a') > '1.2'
             TorchVersion('1.10.0a') > '1.2.1'
@@ -56,6 +50,13 @@ class TorchVersion(str):
             # version like 'parrot'
             return super().__gt__(cmp)
 
+    def __lt__(self, cmp):
+        try:
+            return Version(self).__lt__(self._convert_to_version(cmp))
+        except InvalidVersion:
+            # Fall back to regular string comparison if dealing with an invalid
+            # version like 'parrot'
+            return super().__lt__(cmp)
 
     def __eq__(self, cmp):
         try:
@@ -65,7 +66,6 @@ class TorchVersion(str):
             # version like 'parrot'
             return super().__eq__(cmp)
 
-
     def __ge__(self, cmp):
         try:
             return Version(self).__ge__(self._convert_to_version(cmp))
@@ -74,4 +74,12 @@ class TorchVersion(str):
             # version like 'parrot'
             return super().__ge__(cmp)
 
+    def __le__(self, cmp):
+        try:
+            return Version(self).__le__(self._convert_to_version(cmp))
+        except InvalidVersion:
+            # Fall back to regular string comparison if dealing with an invalid
+            # version like 'parrot'
+            return super().__le__(cmp)
+
 __version__ = TorchVersion(internal_version)