From b7106429697bf6a9e374d68aef839cbcfdfd3bf6 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Tue, 11 Dec 2018 19:11:02 -0800 Subject: [PATCH] Make ATen HIPify out-of-place, but still reuse CUDA names. (#14866) Summary: ``` This diff changes the HIPification of ATen to be out-of-place. We now have the following mappings: - ATen/cuda => ATen/hip - ATen/native/cuda => ATen/native/hip - ATen/native/sparse/cuda => ATen/native/sparse/hip - THC => THH - THCUNN => THHUNN The build system is adjusted to know about these new build paths, and HIPify is taught how to adjust include paths and THC_GENERIC_FILE appropriately. ATen_hip is now built as the ATen_hip library, rather than reusing ATen_cuda. However, despite these new filepaths, none of the identifiers in ATen have actually changed. So, e.g., THHGeneral.h still defines functions named THC_blahblah, and HIP still shows up as CUDA in PyTorch itself. We'll tackle this in a subsequent PR; this diff is just to get the files out-of-place. Minor extra improvements: - Don't edit tmp_install when hipifying - HIP no longer builds native_cudnn_cpp; it was unnecessary - Caffe2_HIP_INCLUDES is now Caffe2_HIP_INCLUDE, for consistency with all the other variables. - HIP build now properly respects ATEN_CUDA_FILES_GEN_LIB (it did not previously.) - You can now override file extension matching in pyHIPIFY by explicitly specifying its full name in the matching list. This is used so we can HIPify CMakeLists.txt in some situations. A little bit of string and ceiling wax: - gen.py grows a --rocm flag so that it knows to generate CUDA files which actually refer to the HIP headers (e.g., THH.h) We'll get rid of this eventually and generate real HIP files, but not for this PR. - Management of HIP dependencies is now completely deleted from the ATen CMakeLists.txt. The old code was dead (because it was shoveled in ATen_CUDA_DEPENDENCY_LIBS and promptly ignored by the Caffe2 build system) and didn't actually work. ``` Stacked on https://github.com/pytorch/pytorch/pull/14849 review last commit only Pull Request resolved: https://github.com/pytorch/pytorch/pull/14866 Differential Revision: D13419475 Pulled By: ezyang fbshipit-source-id: cb4c843df69a1d8369314c9fab1b7719520fa3db --- aten/CMakeLists.txt | 15 ++- aten/src/ATen/CMakeLists.txt | 125 +++++++++++++++++------ aten/src/ATen/gen.py | 39 +++++-- aten/src/ATen/miopen/Utils.h | 2 +- aten/src/ATen/native/miopen/BatchNorm_miopen.cpp | 2 + aten/src/ATen/native/miopen/Conv_miopen.cpp | 5 +- caffe2/CMakeLists.txt | 35 ++++--- cmake/Codegen.cmake | 6 ++ cmake/Dependencies.cmake | 9 +- cmake/public/utils.cmake | 2 +- third_party/sleef | 2 +- tools/amd_build/build_amd.py | 11 +- tools/amd_build/pyHIPIFY/hipify_python.py | 51 +++++++-- tools/cwrap/plugins/NNExtension.py | 4 +- torch/csrc/autograd/engine.cpp | 9 +- 15 files changed, 238 insertions(+), 79 deletions(-) diff --git a/aten/CMakeLists.txt b/aten/CMakeLists.txt index e6b56e3..25c8d04 100644 --- a/aten/CMakeLists.txt +++ b/aten/CMakeLists.txt @@ -21,9 +21,14 @@ set(ATen_THIRD_PARTY_INCLUDE) set(ATen_CUDA_SRCS) set(ATen_CUDA_TEST_SRCS) set(ATen_CUDA_INCLUDE) +set(ATen_HIP_SRCS) +set(ATen_HIP_TEST_SRCS) +set(ATen_HIP_INCLUDE) set(ATen_CPU_DEPENDENCY_LIBS) set(ATen_CUDA_DEPENDENCY_LIBS) +set(ATen_HIP_DEPENDENCY_LIBS) set(ATen_PUBLIC_CUDA_DEPENDENCY_LIBS) +set(ATen_PUBLIC_HIP_DEPENDENCY_LIBS) SET(ATEN_INSTALL_BIN_SUBDIR "bin" CACHE PATH "ATen install binary subdirectory") SET(ATEN_INSTALL_LIB_SUBDIR "lib" CACHE PATH "ATen install library subdirectory") SET(ATEN_INSTALL_INCLUDE_SUBDIR "include" CACHE PATH "ATen install include subdirectory") @@ -58,9 +63,11 @@ IF(MSVC) ENDIF(MSVC) if(USE_ROCM) + # TODO: AT_HIP_ENABLED (change this once we represent HIP as HIP in + # ATen proper) SET(AT_CUDA_ENABLED 1) - add_subdirectory(src/THC) - add_subdirectory(src/THCUNN) + add_subdirectory(src/THH) + add_subdirectory(src/THHUNN) message("ROCm is enabled.") elseif(USE_CUDA) SET(AT_CUDA_ENABLED 1) @@ -79,11 +86,15 @@ add_subdirectory(src/ATen) # Pass source, includes, and libs to parent set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE) set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} PARENT_SCOPE) +set(ATen_HIP_SRCS ${ATen_HIP_SRCS} PARENT_SCOPE) set(ATen_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS} PARENT_SCOPE) set(ATen_CUDA_TEST_SRCS ${ATen_CUDA_TEST_SRCS} PARENT_SCOPE) +set(ATen_HIP_TEST_SRCS ${ATen_HIP_TEST_SRCS} PARENT_SCOPE) set(ATen_CPU_INCLUDE ${ATen_CPU_INCLUDE} PARENT_SCOPE) set(ATen_CUDA_INCLUDE ${ATen_CUDA_INCLUDE} PARENT_SCOPE) +set(ATen_HIP_INCLUDE ${ATen_HIP_INCLUDE} PARENT_SCOPE) set(ATen_THIRD_PARTY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE} PARENT_SCOPE) set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE) +set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_CORE_TEST_SRCS ${ATen_CORE_TEST_SRCS} PARENT_SCOPE) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 584ad22..8edb2ebc 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -17,7 +17,14 @@ IF(NOT AT_INSTALL_BIN_DIR OR NOT AT_INSTALL_LIB_DIR OR NOT AT_INSTALL_INCLUDE_DI ENDIF() CONFIGURE_FILE(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h") -CONFIGURE_FILE(cuda/CUDAConfig.h.in "${CMAKE_CURRENT_SOURCE_DIR}/cuda/CUDAConfig.h") +# TODO: Don't unconditionally generate CUDAConfig.h.in. Unfortuantely, +# this file generates AT_ROCM_ENABLED() which is required by the miopen +# files, which are compiled even if we are doing a vanilla CUDA build. +# Once we properly split CUDA and HIP in ATen, we can remove this code. +configure_file(cuda/CUDAConfig.h.in "${CMAKE_CURRENT_SOURCE_DIR}/cuda/CUDAConfig.h") +if(USE_ROCM) + configure_file(hip/HIPConfig.h.in "${CMAKE_CURRENT_SOURCE_DIR}/hip/HIPConfig.h") +endif() # NB: If you edit these globs, you'll have to update setup.py package_data as well FILE(GLOB base_h "*.h" "detail/*.h" "cpu/*.h") @@ -28,21 +35,32 @@ FILE(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp") FILE(GLOB cuda_cu "cuda/*.cu" "cuda/detail/*.cu") FILE(GLOB cudnn_h "cudnn/*.h" "cudnn/*.cuh") FILE(GLOB cudnn_cpp "cudnn/*.cpp") + +FILE(GLOB hip_h "hip/*.h" "hip/detail/*.h" "hip/*.cuh" "hip/detail/*.cuh") +FILE(GLOB hip_cpp "hip/*.cpp" "hip/detail/*.cpp") +FILE(GLOB hip_hip "hip/*.hip" "hip/detail/*.hip") FILE(GLOB miopen_h "miopen/*.h") FILE(GLOB miopen_cpp "miopen/*.cpp") + FILE(GLOB mkl_cpp "mkl/*.cpp") FILE(GLOB mkldnn_cpp "mkldnn/*.cpp") FILE(GLOB native_cpp "native/*.cpp") +FILE(GLOB native_mkl_cpp "native/mkl/*.cpp") +FILE(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp") FILE(GLOB native_sparse_cpp "native/sparse/*.cpp") + +FILE(GLOB native_cuda_cu "native/cuda/*.cu") +FILE(GLOB native_cuda_cpp "native/cuda/*.cpp") +FILE(GLOB native_cudnn_cpp "native/cudnn/*.cpp") FILE(GLOB native_sparse_cuda_cu "native/sparse/cuda/*.cu") FILE(GLOB native_sparse_cuda_cpp "native/sparse/cuda/*.cpp") -FILE(GLOB native_cudnn_cpp "native/cudnn/*.cpp") + +FILE(GLOB native_hip_hip "native/hip/*.hip") +FILE(GLOB native_hip_cpp "native/hip/*.cpp") FILE(GLOB native_miopen_cpp "native/miopen/*.cpp") -FILE(GLOB native_cuda_cu "native/cuda/*.cu") -FILE(GLOB native_cuda_cpp "native/cuda/*.cpp") -FILE(GLOB native_mkl_cpp "native/mkl/*.cpp") -FILE(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp") +FILE(GLOB native_sparse_hip_hip "native/sparse/hip/*.hip") +FILE(GLOB native_sparse_hip_cpp "native/sparse/hip/*.cpp") set(all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_sparse_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${generated_cpp} ${ATen_CPU_SRCS} ${cpu_kernel_cpp}) if(AT_MKL_ENABLED) @@ -52,22 +70,32 @@ if(AT_MKLDNN_ENABLED) set(all_cpu_cpp ${all_cpu_cpp} ${mkldnn_cpp}) endif() -IF(USE_CUDA OR USE_ROCM) +if(USE_CUDA AND USE_ROCM) + message(FATAL_ERROR "ATen doesn't not currently support simultaneously building with CUDA and ROCM") +endif() + +IF(USE_CUDA) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/cuda) set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} ${cuda_cu} ${native_cuda_cu} ${native_sparse_cuda_cu}) set(all_cuda_cpp ${native_sparse_cuda_cpp} ${cuda_cpp} ${native_cuda_cpp} ${cuda_generated_cpp} ${ATen_CUDA_SRCS}) - IF(USE_CUDA) - SET(all_cuda_cpp ${native_cudnn_cpp} ${native_miopen_cpp} ${all_cuda_cpp}) - IF(CUDNN_FOUND) - SET(all_cuda_cpp ${all_cuda_cpp} ${cudnn_cpp}) - ENDIF() - ELSEIF(USE_ROCM) - SET(all_cuda_cpp ${native_cudnn_cpp} ${native_miopen_cpp} ${miopen_cpp} ${all_cuda_cpp}) + SET(all_cuda_cpp ${native_cudnn_cpp} ${native_miopen_cpp} ${all_cuda_cpp}) + IF(CUDNN_FOUND) + SET(all_cuda_cpp ${all_cuda_cpp} ${cudnn_cpp}) ENDIF() endif() +IF(USE_ROCM) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip) + set(ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} ${native_hip_hip} ${native_sparse_hip_hip}) + # TODO: Codegen separate files for HIP and use those (s/cuda_generated_cpp/hip_generated_cpp) + set(all_hip_cpp ${native_sparse_hip_cpp} ${hip_cpp} ${native_hip_cpp} ${cuda_generated_cpp} ${ATen_HIP_SRCS}) + set(all_hip_cpp ${native_miopen_cpp} ${miopen_cpp} ${all_hip_cpp}) +endif() + filter_list(generated_h generated_cpp "\\.h$") filter_list(cuda_generated_h cuda_generated_cpp "\\.h$") +# TODO: When we have hip_generated_cpp +#filter_list(hip_generated_h hip_generated_cpp "\\.h$") list(APPEND ATen_CPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/..) # so the build can find the generated header files @@ -81,21 +109,28 @@ IF(BLAS_FOUND) MESSAGE(STATUS "TH_BINARY_BUILD detected. Enabling special linkage.") list(APPEND ATen_CPU_DEPENDENCY_LIBS "${BLAS_LIBRARIES};${BLAS_LIBRARIES};${BLAS_LIBRARIES}") - if(USE_CUDA OR USE_ROCM) + if(USE_CUDA) list(APPEND ATen_CUDA_DEPENDENCY_LIBS "${BLAS_LIBRARIES};${BLAS_LIBRARIES};${BLAS_LIBRARIES}") endif() + if(USE_ROCM) + list(APPEND ATen_HIP_DEPENDENCY_LIBS + "${BLAS_LIBRARIES};${BLAS_LIBRARIES};${BLAS_LIBRARIES}") + endif() ELSE ($ENV{TH_BINARY_BUILD}) list(APPEND ATen_CPU_DEPENDENCY_LIBS ${BLAS_LIBRARIES}) - if(USE_CUDA OR USE_ROCM) + if(USE_CUDA) list(APPEND ATen_CUDA_DEPENDENCY_LIBS "${BLAS_LIBRARIES}") endif() + if(USE_ROCM) + list(APPEND ATen_HIP_DEPENDENCY_LIBS "${BLAS_LIBRARIES}") + endif() ENDIF ($ENV{TH_BINARY_BUILD}) ENDIF(BLAS_FOUND) IF(LAPACK_FOUND) list(APPEND ATen_CPU_DEPENDENCY_LIBS ${LAPACK_LIBRARIES}) - if(USE_CUDA OR USE_ROCM) + if(USE_CUDA) # Although Lapack provides CPU (and thus, one might expect that ATen_cuda # would not need this at all), some of our libraries (magma in particular) # backend to CPU BLAS/LAPACK implementations, and so it is very important @@ -104,6 +139,11 @@ IF(LAPACK_FOUND) # This caused https://github.com/pytorch/pytorch/issues/7353 list(APPEND ATen_CUDA_DEPENDENCY_LIBS ${LAPACK_LIBRARIES}) endif() + if(USE_ROCM) + # It's not altogether clear that HIP behaves the same way, but it + # seems safer to assume that it needs it too + list(APPEND ATen_HIP_DEPENDENCY_LIBS ${LAPACK_LIBRARIES}) + endif() ENDIF(LAPACK_FOUND) IF (UNIX AND NOT APPLE) @@ -258,22 +298,21 @@ IF(USE_CUDA AND NOT USE_ROCM) ENDIF($ENV{ATEN_STATIC_CUDA}) ENDIF() -IF(USE_ROCM) - ### Link in the ROCm libraries BLAS / RNG. - FIND_LIBRARY(ROCBLAS_LIBRARY rocblas HINTS ${ROCBLAS_PATH}/lib) - FIND_LIBRARY(HIPRAND_LIBRARY hiprand HINTS ${HIPRAND_PATH}/lib) +# NB: We're relying on cmake/Dependencies.cmake to appropriately setup HIP dependencies. +# In principle we could duplicate them, but handling the rocblas +# dependency is nontrivial. So better not to copy-paste. +# Look for Note [rocblas cmake bug] - list(APPEND ATen_CUDA_DEPENDENCY_LIBS ${ROCBLAS_LIBRARY} ${HIPRAND_LIBRARY}) -ENDIF() - -# Include CPU paths for CUDA as well +# Include CPU paths for CUDA/HIP as well list(APPEND ATen_CUDA_INCLUDE ${ATen_CPU_INCLUDE}) +list(APPEND ATen_HIP_INCLUDE ${ATen_CPU_INCLUDE}) # We have two libraries: libATen_cpu.so and libATen_cuda.so, # with libATen_cuda.so depending on libATen_cpu.so. The CPU library # contains CPU code only. libATen_cpu.so is invariant to the setting # of USE_CUDA (it always builds the same way); libATen_cuda.so is only -# built when USE_CUDA=1 and CUDA is available. +# built when USE_CUDA=1 and CUDA is available. (libATen_hip.so works +# the same way as libATen_cuda.so) set(ATen_CPU_SRCS ${all_cpu_cpp}) if(AT_LINK_STYLE STREQUAL "INTERFACE") # Source code can't be added to an interface library, so it is @@ -297,7 +336,7 @@ else() set(ATen_CPU_SRCS) endif() -if(USE_CUDA OR USE_ROCM) +if(USE_CUDA) set(ATen_CUDA_SRCS ${all_cuda_cpp}) if(AT_LINK_STYLE STREQUAL "INTERFACE") # Source code can't be added to an interface library, so it is @@ -309,6 +348,21 @@ if(USE_CUDA OR USE_ROCM) endif() endif() +if(USE_ROCM) + set(ATen_HIP_SRCS ${all_hip_cpp}) + if(AT_LINK_STYLE STREQUAL "INTERFACE") + # Source code can't be added to an interface library, so it is + # passed back to be compiled into the containing library + add_library(ATen_hip INTERFACE) + # NB: Instead of adding it to this list, we add it by hand + # to caffe2_hip, because it needs to be a PRIVATE dependency + # list(APPEND ATen_HIP_DEPENDENCY_LIBS ATEN_CUDA_FILES_GEN_LIB) + else() + message(FATAL_ERROR "Non-INTERFACE AT_LINK_STYLE not (yet) supported for ROCm build") + endif() +endif() + + if(NOT AT_LINK_STYLE STREQUAL "INTERFACE") if(USE_CUDA) if (NOT $ENV{ATEN_STATIC_CUDA}) @@ -319,16 +373,22 @@ if(NOT AT_LINK_STYLE STREQUAL "INTERFACE") if(NOT MSVC) torch_compile_options(ATen_cpu) - if(USE_CUDA OR USE_ROCM) + if(USE_CUDA) torch_compile_options(ATen_cuda) endif() + if(USE_ROCM) + torch_compile_options(ATen_hip) + endif() endif() if(NOT ${CMAKE_VERSION} VERSION_LESS "3.1") set_property(TARGET ATen_cpu PROPERTY CXX_STANDARD 11) - if(USE_CUDA OR USE_ROCM) + if(USE_CUDA) set_property(TARGET ATen_cuda PROPERTY CXX_STANDARD 11) endif() + if(USE_ROCM) + set_property(TARGET ATen_hip PROPERTY CXX_STANDARD 11) + endif() endif() endif() @@ -338,11 +398,12 @@ INSTALL(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake" DESTINATION "${AT_INSTALL_SHARE_DIR}/cmake/ATen") # https://stackoverflow.com/questions/11096471/how-can-i-install-a-hierarchy-of-files-using-cmake -FOREACH(HEADER ${base_h} ${ATen_CORE_HEADERS} ${cuda_h} ${cudnn_h}) +FOREACH(HEADER ${base_h} ${ATen_CORE_HEADERS} ${cuda_h} ${cudnn_h} ${hip_h} ${miopen_h}) string(REPLACE "${CMAKE_CURRENT_SOURCE_DIR}/" "" HEADER_SUB ${HEADER}) GET_FILENAME_COMPONENT(DIR ${HEADER_SUB} DIRECTORY) INSTALL(FILES ${HEADER} DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen/${DIR}) ENDFOREACH() +# TODO: Install hip_generated_h when we have it FOREACH(HEADER ${generated_h} ${cuda_generated_h}) # NB: Assumed to be flat INSTALL(FILES ${HEADER} DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen) @@ -360,11 +421,15 @@ endif() set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE) set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE) set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} PARENT_SCOPE) +set(ATen_HIP_SRCS ${ATen_HIP_SRCS} PARENT_SCOPE) set(ATen_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS} PARENT_SCOPE) set(ATen_CUDA_TEST_SRCS ${ATen_CUDA_TEST_SRCS} PARENT_SCOPE) set(ATen_CORE_TEST_SRCS ${ATen_CORE_TEST_SRCS} PARENT_SCOPE) +set(ATen_HIP_TEST_SRCS ${ATen_HIP_TEST_SRCS} PARENT_SCOPE) set(ATen_CPU_INCLUDE ${ATen_CPU_INCLUDE} PARENT_SCOPE) set(ATen_THIRD_PARTY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE} PARENT_SCOPE) set(ATen_CUDA_INCLUDE ${ATen_CUDA_INCLUDE} PARENT_SCOPE) +set(ATen_HIP_INCLUDE ${ATen_HIP_INCLUDE} PARENT_SCOPE) set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE) +set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE) diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py index a4e6be5..9a60112 100644 --- a/aten/src/ATen/gen.py +++ b/aten/src/ATen/gen.py @@ -38,6 +38,10 @@ parser.add_argument( help='output a list of dependencies into the given file and exit') parser.add_argument( '-d', '--install_dir', help='output directory', default='ATen') +parser.add_argument( + '--rocm', + action='store_true', + help='reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly') options = parser.parse_args() gen_to_source = os.environ.get('GEN_TO_SOURCE') # update source directly as part of gen if not gen_to_source: @@ -154,7 +158,7 @@ generators = { 'CUDAGenerator.h': { 'name': 'CUDA', 'th_generator': '', - 'header': 'THC/THC.h' + 'header': 'THC/THC.h' if not options.rocm else 'THH/THH.h' }, } @@ -259,17 +263,30 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations top_env['type_ids'].append(tag + ',') if backend == 'CUDA': - env['th_headers'] = [ - '#include ', - '#include ', - '#include ', - '#undef THNN_', - '#undef THCIndexTensor_', - ] - env['extra_cuda_headers'] = ['#include '] + env['extra_cuda_headers'] = [] env['extra_cuda_headers'].append('#include ') - env['extra_cuda_headers'].append('#include ') - env['extra_cuda_headers'].append('#include ') + if options.rocm: + env['th_headers'] = [ + '#include ', + '#include ', + '#include ', + '#undef THNN_', + '#undef THCIndexTensor_', + ] + env['extra_cuda_headers'].append('#include ') + env['extra_cuda_headers'].append('#include ') + env['extra_cuda_headers'].append('#include ') + else: + env['th_headers'] = [ + '#include ', + '#include ', + '#include ', + '#undef THNN_', + '#undef THCIndexTensor_', + ] + env['extra_cuda_headers'].append('#include ') + env['extra_cuda_headers'].append('#include ') + env['extra_cuda_headers'].append('#include ') sname = '' if scalar_name == "Float" else scalar_name env['THType'] = 'Cuda{}'.format(sname) env['THStorage'] = 'THCuda{}Storage'.format(sname) diff --git a/aten/src/ATen/miopen/Utils.h b/aten/src/ATen/miopen/Utils.h index 58bfa9f..c264650 100644 --- a/aten/src/ATen/miopen/Utils.h +++ b/aten/src/ATen/miopen/Utils.h @@ -1,7 +1,7 @@ #pragma once #include -#include +#include #include #include diff --git a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp index 4a3834f..8e5110c 100644 --- a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp +++ b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp @@ -2,6 +2,8 @@ #include #include +// TODO: Remove the condition on AT_ROCM_ENABLED entirely, +// don't build this file as part of CPU build. #include #if !AT_ROCM_ENABLED() diff --git a/aten/src/ATen/native/miopen/Conv_miopen.cpp b/aten/src/ATen/native/miopen/Conv_miopen.cpp index f931ee6..3f8651c 100644 --- a/aten/src/ATen/native/miopen/Conv_miopen.cpp +++ b/aten/src/ATen/native/miopen/Conv_miopen.cpp @@ -1,6 +1,9 @@ #include #include #include + +// TODO: Remove the condition on AT_ROCM_ENABLED entirely, +// don't build this file as part of CPU build. #include #if !AT_ROCM_ENABLED() @@ -74,7 +77,7 @@ std::tuple miopen_convolution_transpose_backwa #else // AT_ROCM_ENABLED -#include +#include #include #include diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index f4b3db1..2b110e8 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -32,20 +32,18 @@ if (NOT BUILD_ATEN_MOBILE) # Add source, includes, and libs to lists list(APPEND Caffe2_CPU_SRCS ${ATen_CPU_SRCS}) list(APPEND Caffe2_GPU_SRCS ${ATen_CUDA_SRCS}) + list(APPEND Caffe2_HIP_SRCS ${ATen_HIP_SRCS}) list(APPEND Caffe2_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS}) list(APPEND Caffe2_GPU_TEST_SRCS ${ATen_CUDA_TEST_SRCS}) + list(APPEND Caffe2_HIP_TEST_SRCS ${ATen_HIP_TEST_SRCS}) list(APPEND Caffe2_CPU_TEST_SRCS ${ATen_CORE_TEST_SRCS}) list(APPEND Caffe2_CPU_INCLUDE ${ATen_CPU_INCLUDE}) list(APPEND Caffe2_GPU_INCLUDE ${ATen_CUDA_INCLUDE}) + list(APPEND Caffe2_HIP_INCLUDE ${ATen_HIP_INCLUDE}) list(APPEND Caffe2_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS}) list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS}) + list(APPEND Caffe2_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS}) list(APPEND Caffe2_DEPENDENCY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE}) - - IF(USE_ROCM) - # Set the HIP Variables - set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS} ${ATen_CUDA_SRCS}) - set(Caffe2_HIP_INCLUDES ${Caffe2_HIP_INCLUDES} ${Caffe2_GPU_INCLUDE}) - ENDIF(USE_ROCM) else() # Only add "ATen Core", a minimal, easy-to-compile fragment of ATen. # This codepath should only be exercised by the Android build. @@ -168,6 +166,11 @@ if (FALSE) foreach(tmp ${ATen_CUDA_TEST_SRCS}) message(STATUS " " ${tmp}) endforeach() + + message(STATUS "ATen HIP test sources: ") + foreach(tmp ${ATen_HIP_TEST_SRCS}) + message(STATUS " " ${tmp}) + endforeach() endif() # ---[ List of libraries to link with @@ -369,7 +372,7 @@ endif() # ---[ Caffe2 HIP sources. if(USE_ROCM) - # Call again since Caffe2_HIP_INCLUDES is extended with ATen include dirs. + # Call again since Caffe2_HIP_INCLUDE is extended with ATen include dirs. # Get Compile Definitions from the directory (FindHIP.cmake bug) get_directory_property(MY_DEFINITIONS COMPILE_DEFINITIONS) if(MY_DEFINITIONS) @@ -378,8 +381,8 @@ if(USE_ROCM) endforeach() endif() - # Call again since Caffe2_HIP_INCLUDES is extended with ATen include dirs. - hip_include_directories(${Caffe2_HIP_INCLUDES}) + # Call again since Caffe2_HIP_INCLUDE is extended with ATen include dirs. + hip_include_directories(${Caffe2_HIP_INCLUDE}) filter_list(__caffe2_hip_srcs_cpp Caffe2_HIP_SRCS "\\.(cu|hip)$") set_source_files_properties(${__caffe2_hip_srcs_cpp} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) @@ -391,10 +394,18 @@ if(USE_ROCM) target_compile_options(caffe2_hip PRIVATE ${HIP_CXX_FLAGS}) target_link_libraries(caffe2_hip PUBLIC caffe2) target_link_libraries(caffe2_hip PUBLIC c10_hip) + if(NOT BUILD_ATEN_MOBILE) + # TODO: Cut this over to ATEN_HIP_FILES_GEN_LIB. At the moment, we + # only generate CUDA files + # NB: This dependency must be PRIVATE, because we don't install + # ATEN_CUDA_FILES_GEN_LIB (it's a synthetic target just to get the + # correct dependency from generated files.) + target_link_libraries(caffe2_hip PRIVATE ATEN_CUDA_FILES_GEN_LIB) + endif() target_link_libraries(caffe2_hip PUBLIC ${Caffe2_HIP_DEPENDENCY_LIBS}) # Since PyTorch files contain HIP headers, this is also needed to capture the includes. - target_include_directories(caffe2_hip PRIVATE ${Caffe2_HIP_INCLUDES}) + target_include_directories(caffe2_hip PRIVATE ${Caffe2_HIP_INCLUDE}) target_include_directories(caffe2_hip INTERFACE $) # Set standard properties on the target @@ -447,7 +458,7 @@ if (BUILD_TEST) add_executable(${test_name} "${test_src}") target_link_libraries(${test_name} ${Caffe2_MAIN_LIBS} gtest_main) target_include_directories(${test_name} PRIVATE $) - target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE} ${Caffe2_HIP_INCLUDES}) + target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE} ${Caffe2_HIP_INCLUDE}) target_compile_options(${test_name} PRIVATE ${HIP_CXX_FLAGS}) add_test(NAME ${test_name} COMMAND $) if (INSTALL_TEST) @@ -583,7 +594,7 @@ if (BUILD_PYTHON) set_target_properties(caffe2_pybind11_state_hip PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") endif() target_include_directories(caffe2_pybind11_state_hip PRIVATE $) - target_include_directories(caffe2_pybind11_state_hip PRIVATE ${Caffe2_CPU_INCLUDE} ${Caffe2_HIP_INCLUDES}) + target_include_directories(caffe2_pybind11_state_hip PRIVATE ${Caffe2_CPU_INCLUDE} ${Caffe2_HIP_INCLUDE}) target_link_libraries( caffe2_pybind11_state_hip caffe2_library caffe2_hip_library) if (WIN32) diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index ce04ca0..f496313 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -134,10 +134,16 @@ if (NOT BUILD_ATEN_MOBILE) FILE(GLOB all_python "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/*.py") + set(GEN_ROCM_FLAG) + if (USE_ROCM) + set(GEN_ROCM_FLAG --rocm) + endif() + SET(GEN_COMMAND ${PYCMD} ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/gen.py --source-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen --install_dir ${CMAKE_BINARY_DIR}/aten/src/ATen + ${GEN_ROCM_FLAG} ${cwrap_files} ) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 1b84bf1..a0853a4 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -753,11 +753,11 @@ if(USE_ROCM) list(APPEND HIP_HCC_FLAGS -fno-gpu-rdc) list(APPEND HIP_HCC_FLAGS -amdgpu-target=${HCC_AMDGPU_TARGET}) - set(Caffe2_HIP_INCLUDES - ${hip_INCLUDE_DIRS} ${hcc_INCLUDE_DIRS} ${hsa_INCLUDE_DIRS} ${rocrand_INCLUDE_DIRS} ${hiprand_INCLUDE_DIRS} ${rocblas_INCLUDE_DIRS} ${miopen_INCLUDE_DIRS} ${thrust_INCLUDE_DIRS} $ ${Caffe2_HIP_INCLUDES}) + set(Caffe2_HIP_INCLUDE + ${hip_INCLUDE_DIRS} ${hcc_INCLUDE_DIRS} ${hsa_INCLUDE_DIRS} ${rocrand_INCLUDE_DIRS} ${hiprand_INCLUDE_DIRS} ${rocblas_INCLUDE_DIRS} ${miopen_INCLUDE_DIRS} ${thrust_INCLUDE_DIRS} $ ${Caffe2_HIP_INCLUDE}) # This is needed for library added by hip_add_library (same for hip_add_executable) - hip_include_directories(${Caffe2_HIP_INCLUDES}) + hip_include_directories(${Caffe2_HIP_INCLUDE}) set(Caffe2_HIP_DEPENDENCY_LIBS ${rocrand_LIBRARIES} ${hiprand_LIBRARIES} ${PYTORCH_HIP_HCC_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES}) @@ -765,7 +765,10 @@ if(USE_ROCM) if(NOT BUILD_ATEN_MOBILE) set(Caffe2_HIP_DEPENDENCY_LIBS ${Caffe2_HIP_DEPENDENCY_LIBS} ${hipsparse_LIBRARIES}) endif() + # Note [rocblas cmake bug] + # ~~~~~~~~~~~~~~~~~~~~~~~~ # TODO: There is a bug in rocblas's cmake files that exports the wrong targets name in ${rocblas_LIBRARIES} + # If you get this wrong, you'll get a complaint like 'ld: cannot find -lrocblas-targets' list(APPEND Caffe2_HIP_DEPENDENCY_LIBS roc::rocblas) diff --git a/cmake/public/utils.cmake b/cmake/public/utils.cmake index a464a10..7178374 100644 --- a/cmake/public/utils.cmake +++ b/cmake/public/utils.cmake @@ -124,7 +124,7 @@ function(caffe2_hip_binary_target target_name_or_src) caffe2_binary_target(${target_name_or_src}) target_compile_options(${__target} PRIVATE ${HIP_CXX_FLAGS}) - target_include_directories(${__target} PRIVATE ${Caffe2_HIP_INCLUDES}) + target_include_directories(${__target} PRIVATE ${Caffe2_HIP_INCLUDE}) endfunction() ############################################################################## diff --git a/third_party/sleef b/third_party/sleef index 191f655..6ff7a13 160000 --- a/third_party/sleef +++ b/third_party/sleef @@ -1 +1 @@ -Subproject commit 191f655caa25526ae226cf88dd2529265176014a +Subproject commit 6ff7a135a1e31979d1e1844a2e7171dfbd34f54f diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index eeda8d8..2927e36 100644 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -45,16 +45,23 @@ includes = [ "aten/src/THC/*", "aten/src/THCUNN/*", "aten/src/ATen/test/*", + # CMakeLists.txt isn't processed by default, but there are a few + # we do want to handle, so explicitly specify them + "aten/src/THC/CMakeLists.txt", + "aten/src/THCUNN/CMakeLists.txt", "torch/*", ] ignores = [ "caffe2/operators/depthwise_3x3_conv_op_cudnn.cu", "caffe2/operators/pool_op_cudnn.cu", - '**/hip/**', + '*/hip/*', # These files are compatible with both cuda and hip "aten/src/ATen/core/*", - "torch/csrc/autograd/engine.cpp" + "torch/csrc/autograd/engine.cpp", + # generated files we shouldn't frob + "torch/lib/tmp_install/*", + "torch/lib/include/*", ] json_settings = os.path.join(amd_build_dir, "disabled_features.json") diff --git a/tools/amd_build/pyHIPIFY/hipify_python.py b/tools/amd_build/pyHIPIFY/hipify_python.py index dbffceb..f3dde66 100755 --- a/tools/amd_build/pyHIPIFY/hipify_python.py +++ b/tools/amd_build/pyHIPIFY/hipify_python.py @@ -211,7 +211,9 @@ def matched_files_iter(root_path, includes=('*',), ignores=(), extensions=(), ou def match_extensions(filename): """Helper method to see if filename ends with certain extension""" - return os.path.splitext(filename)[1] in extensions + return any(filename.endswith(e) for e in extensions) + + exact_matches = set(includes) # This is a very rough heuristic; really, we want to avoid scanning # any file which is not checked into source control, but this script @@ -230,7 +232,9 @@ def matched_files_iter(root_path, includes=('*',), ignores=(), extensions=(), ou dirs.remove("third_party") for filename in filenames: filepath = os.path.join(rel_dirpath, filename) - if _fnmatch(filepath, includes) and (not _fnmatch(filepath, ignores)) and match_extensions(filepath): + # We respect extensions, UNLESS you wrote the entire + # filename verbatim, in which case we always accept it + if _fnmatch(filepath, includes) and (not _fnmatch(filepath, ignores)) and (match_extensions(filepath) or filepath in exact_matches): if not is_pytorch_file(filepath) and not is_caffe2_gpu_file(filepath): continue if out_of_place_only and not is_out_of_place(filepath): @@ -714,11 +718,9 @@ def get_hip_file_path(filepath): """ Returns the new name of the hipified file """ - # At the moment, PyTorch is HIPIFYed in-place. We'd prefer for this - # to not be the case, but we can't conveniently do this until we - # also fix up PyTorch's build system to know how to handle things - # out-of-place. - if is_pytorch_file(filepath): + # At the moment, some files are HIPified in place. The predicate + # is_out_of_place tells us if this is the case or not. + if not is_out_of_place(filepath): return filepath dirpath, filename = os.path.split(filepath) @@ -762,8 +764,13 @@ def get_hip_file_path(filepath): orig_dirpath = dirpath dirpath = dirpath.replace('cuda', 'hip') + dirpath = dirpath.replace('THC', 'THH') + root = root.replace('cuda', 'hip') root = root.replace('CUDA', 'HIP') + # Special case to handle caffe2/core/THCCachingAllocator + if dirpath != "caffe2/core": + root = root.replace('THC', 'THH') if dirpath == orig_dirpath: dirpath = os.path.join(dirpath, 'hip') @@ -772,7 +779,9 @@ def get_hip_file_path(filepath): def is_out_of_place(filepath): - return not is_pytorch_file(filepath) + if filepath.startswith("torch/"): + return False + return True # Keep this synchronized with includes/ignores in build_amd.py @@ -872,6 +881,11 @@ for mapping in CUDA_TO_HIP_MAPPINGS: RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.pattern()) RE_PYTORCH_PREPROCESSOR = re.compile(r'\b{0}\b'.format(PYTORCH_TRIE.pattern())) +RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"') +RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>') +RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"') +RE_CU_SUFFIX = re.compile(r'\.cu\b') # be careful not to pick up .cuh + def preprocessor(output_directory, filepath, stats): """ Executes the CUDA -> HIP conversion on the specified file. """ fin_path = os.path.join(output_directory, filepath) @@ -893,6 +907,24 @@ def preprocessor(output_directory, filepath, stats): return CAFFE2_MAP[m.group(0)] output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source) + # Header rewrites + def mk_repl(templ): + def repl(m): + f = m.group(1) + if f.startswith("ATen/cuda") or f.startswith("ATen/native/cuda") or f.startswith("ATen/native/sparse/cuda") or f.startswith("THC/") or f.startswith("THCUNN/") or (f.startswith("THC") and not f.startswith("THCP")): + return templ.format(get_hip_file_path(m.group(1))) + return m.group(0) + return repl + output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"'), output_source) + output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>'), output_source) + output_source = RE_THC_GENERIC_FILE.sub(mk_repl('#define THC_GENERIC_FILE "{0}"'), output_source) + + # CMakeLists.txt rewrites + if filepath.endswith('CMakeLists.txt'): + output_source = output_source.replace('CUDA', 'HIP') + output_source = output_source.replace('THC', 'THH') + output_source = RE_CU_SUFFIX.sub('.hip', output_source) + # Perform Kernel Launch Replacements output_source = processKernelLaunches(output_source, stats) @@ -1226,7 +1258,7 @@ def add_static_casts(orig_filepath, filepath, KernelTemplateParams): # PyTorch Specific: Add template type # Here the template value will be resolved from to . - if "THCUNN" in filepath.split("/") and "generic" not in filepath.split("/"): + if "THHUNN" in filepath.split("/") and "generic" not in filepath.split("/"): kernel_name_with_template = kernel_name_with_template.replace("", "") full_new_kernel_launch = re.sub(r'\b{0}\b'.format(re.escape(original_kernel_name_with_template)), @@ -1256,7 +1288,6 @@ def str2bool(v): raise argparse.ArgumentTypeError('Boolean value expected.') - def hipify( project_directory, show_detailed=False, diff --git a/tools/cwrap/plugins/NNExtension.py b/tools/cwrap/plugins/NNExtension.py index 47a37e0..b30a433 100644 --- a/tools/cwrap/plugins/NNExtension.py +++ b/tools/cwrap/plugins/NNExtension.py @@ -12,8 +12,10 @@ MODULE_HEAD = """ // HIPify isn't being applied to autogenerated files, so defensively // handle both the CUDA and ROCM cases. -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) #include +#elif defined(USE_ROCM) +#include #endif """ diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 6dfe642..564005c 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -27,16 +27,17 @@ #include #include -#if defined(USE_CUDA) || defined(USE_ROCM) #ifdef USE_CUDA #include +#include +#include #endif // USE_CUDA + #ifdef USE_ROCM #include +#include +#include #endif // USE_ROCM -#include -#include -#endif // defined(USE_CUDA) || defined(USE_ROCM) namespace torch { namespace autograd { -- 2.7.4