--- /dev/null
+void dummy(int) { }
--- /dev/null
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CppExtension
+
+setup(
+ name="no_python_abi_suffix_test",
+ ext_modules=[
+ CppExtension("no_python_abi_suffix_test", ["no_python_abi_suffix_test.cpp"])
+ ],
+ cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
+)
import sys
import torch.cuda
from setuptools import setup
-from torch.utils.cpp_extension import CppExtension, CUDAExtension
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
from torch.utils.cpp_extension import CUDA_HOME
CXX_FLAGS = [] if sys.platform == 'win32' else ['-g', '-Werror']
name='torch_test_cpp_extension',
packages=['torch_test_cpp_extension'],
ext_modules=ext_modules,
- cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension})
+ cmdclass={'build_ext': BuildExtension})
import torch
-SHARED_LIBRARY_NAMES = {
- 'linux': 'libcustom_ops.so',
- 'darwin': 'libcustom_ops.dylib',
- 'win32': 'custom_ops.dll'
-}
-
def get_custom_op_library_path():
- path = os.path.abspath('build/{}'.format(
- SHARED_LIBRARY_NAMES[sys.platform]))
+ if sys.platform.startswith("win32"):
+ library_filename = "custom_ops.dll"
+ elif sys.platform.startswith("darwin"):
+ library_filename = "libcustom_ops.dylib"
+ else:
+ library_filename = "libcustom_ops.so"
+ path = os.path.abspath("build/{}".format(library_filename))
assert os.path.exists(path), path
return path
def main():
parser = argparse.ArgumentParser(
- description="Serialize a script module with custom ops")
+ description="Serialize a script module with custom ops"
+ )
parser.add_argument("--export-script-module-to", required=True)
options = parser.parse_args()
model.save(options.export_script_module_to)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
except RuntimeError:
print(CPP_EXTENSIONS_ERROR)
return 1
+ cpp_extensions_test_dir = os.path.join(test_directory, 'cpp_extensions')
return_code = shell([sys.executable, 'setup.py', 'install', '--root', './install'],
- os.path.join(test_directory, 'cpp_extensions'))
+ cwd=cpp_extensions_test_dir)
if return_code != 0:
return return_code
+ if sys.platform != 'win32':
+ return_code = shell([sys.executable, 'setup.py', 'install', '--root', './install'],
+ cwd=os.path.join(cpp_extensions_test_dir, 'no_python_abi_suffix_test'))
+ if return_code != 0:
+ return return_code
python_path = os.environ.get('PYTHONPATH', '')
try:
class TestCppExtension(common.TestCase):
def setUp(self):
- if sys.platform != 'win32':
+ if not IS_WINDOWS:
default_build_root = torch.utils.cpp_extension.get_default_build_root()
if os.path.exists(default_build_root):
shutil.rmtree(default_build_root)
@unittest.skipIf(not TEST_CUDNN, "CuDNN not found")
def test_jit_cudnn_extension(self):
# implementation of CuDNN ReLU
- if sys.platform == 'win32':
+ if IS_WINDOWS:
extra_ldflags = ['cudnn.lib']
else:
extra_ldflags = ['-lcudnn']
self.assertEqual(len(net.parameters()), 4)
p = net.named_parameters()
- self.assertEqual(type(p), dict)
self.assertEqual(len(p), 4)
self.assertIn('fc.weight', p)
self.assertIn('fc.bias', p)
is_python_module=False)
self.assertEqual(torch.ops.test.func(torch.eye(5)), torch.eye(5))
+ @unittest.skipIf(IS_WINDOWS, "Not available on Windows")
+ def test_no_python_abi_suffix_sets_the_correct_library_name(self):
+ # For this test, run_test.py will call `python setup.py install` in the
+ # cpp_extensions/no_python_abi_suffix_test folder, where the
+ # `BuildExtension` class has a `no_python_abi_suffix` option set to
+ # `True`. This *should* mean that on Python 3, the produced shared
+ # library does not have an ABI suffix like
+ # "cpython-37m-x86_64-linux-gnu" before the library suffix, e.g. "so".
+ # On Python 2 there is no ABI suffix anyway.
+ root = os.path.join("cpp_extensions", "no_python_abi_suffix_test", "build")
+ print(list(os.walk(os.path.join("cpp_extensions", "no_python_abi_suffix_test"))))
+ matches = [f for _, _, fs in os.walk(root) for f in fs if f.endswith("so")]
+ self.assertEqual(len(matches), 1, str(matches))
+ self.assertEqual(matches[0], "no_python_abi_suffix_test.so", str(matches))
+
if __name__ == '__main__':
common.run_tests()
!! WARNING !!
'''
-ACCEPTED_COMPILERS_FOR_PLATFORM = {'darwin': ['clang++', 'clang'], 'linux': ['g++', 'gcc']}
CUDA_HOME = _find_cuda_home()
CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH')
# PyTorch releases have the version pattern major.minor.patch, whereas when
return not BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__)
+def _accepted_compilers_for_platform():
+ return ['clang++', 'clang'] if sys.platform.startswith('darwin') else ['g++', 'gcc']
+
+
def get_default_build_root():
'''
Returns the path to the root folder under which extensions will built.
which = subprocess.check_output(['which', compiler], stderr=subprocess.STDOUT)
# Use os.path.realpath to resolve any symlinks, in particular from 'c++' to e.g. 'g++'.
compiler_path = os.path.realpath(which.decode().strip())
- accepted_compilers = ACCEPTED_COMPILERS_FOR_PLATFORM[sys.platform]
- return any(name in compiler_path for name in accepted_compilers)
+ return any(name in compiler_path for name in _accepted_compilers_for_platform())
def check_compiler_abi_compatibility(compiler):
if not check_compiler_ok_for_platform(compiler):
warnings.warn(WRONG_COMPILER_WARNING.format(
user_compiler=compiler,
- pytorch_compiler=ACCEPTED_COMPILERS_FOR_PLATFORM[sys.platform][0],
+ pytorch_compiler=_accepted_compilers_for_platform()[0],
platform=sys.platform))
return False
- if sys.platform == 'darwin':
+ if sys.platform.startswith('darwin'):
# There is no particular minimum version we need for clang, so we're good here.
return True
try:
- if sys.platform == 'linux':
+ if sys.platform.startswith('linux'):
minimum_required_version = MINIMUM_GCC_VERSION
version = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion'])
version = version.decode().strip().split('.')
return False
-class BuildExtension(build_ext):
+# See below for why we inherit BuildExtension from object.
+# https://stackoverflow.com/questions/1713038/super-fails-with-error-typeerror-argument-1-must-be-type-not-classobj-when
+
+
+class BuildExtension(build_ext, object):
'''
A custom :mod:`setuptools` build extension .
the C++ and CUDA compiler during mixed compilation.
'''
+ @classmethod
+ def with_options(cls, **options):
+ '''
+ Returns an alternative constructor that extends any original keyword
+ arguments to the original constructor with the given options.
+ '''
+ def init_with_options(*args, **kwargs):
+ kwargs = kwargs.copy()
+ kwargs.update(options)
+ return cls(*args, **kwargs)
+ return init_with_options
+
+ def __init__(self, *args, **kwargs):
+ super(BuildExtension, self).__init__(*args, **kwargs)
+ self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", False)
+
def build_extensions(self):
self._check_abi()
for extension in self.extensions:
extra_postargs = None
def spawn(cmd):
- orig_cmd = cmd
# Using regex to match src, obj and include files
-
src_regex = re.compile('/T(p|c)(.*)')
src_list = [
m.group(2) for m in (src_regex.match(elem) for elem in cmd)
build_ext.build_extensions(self)
+ def get_ext_filename(self, ext_name):
+ # Get the original shared library name. For Python 3, this name will be
+ # suffixed with "<SOABI>.so", where <SOABI> will be something like
+ # cpython-37m-x86_64-linux-gnu. On Python 2, there is no such ABI name.
+ # The final extension, .so, would be .lib/.dll on Windows of course.
+ ext_filename = super(BuildExtension, self).get_ext_filename(ext_name)
+ # If `no_python_abi_suffix` is `True`, we omit the Python 3 ABI
+ # component. This makes building shared libraries with setuptools that
+ # aren't Python modules nicer.
+ if self.no_python_abi_suffix and sys.version_info >= (3, 0):
+ # The parts will be e.g. ["my_extension", "cpython-37m-x86_64-linux-gnu", "so"].
+ ext_filename_parts = ext_filename.split('.')
+ # Omit the second to last element.
+ without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:]
+ ext_filename = '.'.join(without_abi)
+ return ext_filename
+
def _check_abi(self):
# On some platforms, like Windows, compiler_cxx is not available.
if hasattr(self.compiler, 'compiler_cxx'):
else:
ldflags = ['-shared'] + extra_ldflags
# The darwin linker needs explicit consent to ignore unresolved symbols.
- if sys.platform == 'darwin':
+ if sys.platform.startswith('darwin'):
ldflags.append('-undefined dynamic_lookup')
elif IS_WINDOWS:
ldflags = _nt_quote_args(ldflags)