From 8deaa476ac33daea6af4cf079eb13757ca86c2ad Mon Sep 17 00:00:00 2001 From: Nived P A <62166124+nivedwho@users.noreply.github.com> Date: Thu, 9 Sep 2021 10:29:10 -0700 Subject: [PATCH] Added more version comparison operations (#63848) Summary: Currently the [TorchVersion](https://github.com/pytorch/pytorch/blob/1022443168b5fad55bbd03d087abf574c9d2e9df/torch/torch_version.py#L13) only only supports 'greater than', and 'equal to' operations for comparing torch versions and something like `TorchVersion('1.5.0') < (1,5,1)` or `TorchVersion('1.5.0') >= (1,5)` will throw an error. I have added 'less than' (`__lt__()`), 'greater than or equal to' (`__ge__()`) and 'less than or equal to' (`__le__()`) operations, so that the TorchVersion object can be useful for wider range of version comparisons. cc seemethere zsol Pull Request resolved: https://github.com/pytorch/pytorch/pull/63848 Reviewed By: fmassa, heitorschueroff Differential Revision: D30526996 Pulled By: seemethere fbshipit-source-id: 1db6bee555043e0719fd541cec27810852590940 --- torch/torch_version.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/torch/torch_version.py b/torch/torch_version.py index 4998c55..f840550 100644 --- a/torch/torch_version.py +++ b/torch/torch_version.py @@ -1,4 +1,3 @@ -from functools import total_ordering from typing import Iterable, Union from pkg_resources import packaging # type: ignore[attr-defined] @@ -9,24 +8,19 @@ InvalidVersion = packaging.version.InvalidVersion from .version import __version__ as internal_version -@total_ordering class TorchVersion(str): """A string with magic powers to compare to both Version and iterables! - Prior to 1.10.0 torch.__version__ was stored as a str and so many did comparisons against torch.__version__ as if it were a str. In order to not break them we have TorchVersion which masquerades as a str while also having the ability to compare against both packaging.version.Version as well as tuples of values, eg. (1, 2, 1) - Examples: Comparing a TorchVersion object to a Version object TorchVersion('1.10.0a') > Version('1.10.0a') - Comparing a TorchVersion object to a Tuple object TorchVersion('1.10.0a') > (1, 2) # 1.2 TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1 - Comparing a TorchVersion object against a string TorchVersion('1.10.0a') > '1.2' TorchVersion('1.10.0a') > '1.2.1' @@ -56,6 +50,13 @@ class TorchVersion(str): # version like 'parrot' return super().__gt__(cmp) + def __lt__(self, cmp): + try: + return Version(self).__lt__(self._convert_to_version(cmp)) + except InvalidVersion: + # Fall back to regular string comparison if dealing with an invalid + # version like 'parrot' + return super().__lt__(cmp) def __eq__(self, cmp): try: @@ -65,7 +66,6 @@ class TorchVersion(str): # version like 'parrot' return super().__eq__(cmp) - def __ge__(self, cmp): try: return Version(self).__ge__(self._convert_to_version(cmp)) @@ -74,4 +74,12 @@ class TorchVersion(str): # version like 'parrot' return super().__ge__(cmp) + def __le__(self, cmp): + try: + return Version(self).__le__(self._convert_to_version(cmp)) + except InvalidVersion: + # Fall back to regular string comparison if dealing with an invalid + # version like 'parrot' + return super().__le__(cmp) + __version__ = TorchVersion(internal_version) -- 2.7.4