Improve HIPify performance (#14803)
authorEdward Yang <ezyang@fb.com>
Wed, 5 Dec 2018 18:57:00 +0000 (10:57 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 5 Dec 2018 19:00:03 +0000 (11:00 -0800)
Summary:
```
    Improve performance of pyHIPIFY

    Changes:
    - Pre-compile regexes, don't use regexes when it's not necessary
      (this saves us ~15%)
    - Compile all substitutions for mappings into a single, non-backtracking
      regex using a Trie.  This gives big savings.

    Before, running pyHIPIFY on all files took 15.8s.  Now it takes 3.9s.
```

Stacked on #14769
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14803

Differential Revision: D13342620

Pulled By: ezyang

fbshipit-source-id: 1cfa36b3236bbe24d07080a31cc788a52d740f40

tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
tools/amd_build/pyHIPIFY/hipify_python.py

index 817d13b..c530c50 100644 (file)
@@ -1905,6 +1905,8 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict([
     ("cublasCgemm3mEx", ("rocblas_cgemm_3mex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)),
     ("cublasZgemm_v2", ("rocblas_zgemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)),
     ("cublasZgemm3m", ("rocblas_zgemm_3m", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)),
+    # NB: The function rocblas_sgemmex doesn't actually exist in
+    # rocblas, as of 2018-12-05
     ("cublasSgemmEx", ("rocblas_sgemmex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)),
     ("cublasGemmEx", ("rocblas_gemmex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)),
     ("cublasCgemmEx", ("rocblas_cgemmex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)),
index 09c0566..9af02ce 100755 (executable)
@@ -321,10 +321,13 @@ def add_dim3(kernel_string, cuda_kernel):
     return cuda_kernel
 
 
+RE_KERNEL_LAUNCH = re.compile(r'([ ]+)(detail?)::[ ]+\\\n[ ]+')
+
+
 def processKernelLaunches(string, stats):
     """ Replace the CUDA style Kernel launches with the HIP style kernel launches."""
     # Concat the namespace with the kernel names. (Find cleaner way of doing this later).
-    string = re.sub(r'([ ]+)(detail?)::[ ]+\\\n[ ]+', lambda inp: "{0}{1}::".format(inp.group(1), inp.group(2)), string)
+    string = RE_KERNEL_LAUNCH.sub(lambda inp: "{0}{1}::".format(inp.group(1), inp.group(2)), string)
 
     def grab_method_and_template(in_kernel):
         # The positions for relevant kernel components.
@@ -487,12 +490,15 @@ def find_parentheses_group(input_string, start):
     return find_closure_group(input_string, start, group=["(", ")"])
 
 
+RE_ASSERT = re.compile(r"\bassert[ ]*\(")
+
+
 def disable_asserts(input_string):
     """ Disables regular assert statements
     e.g. "assert(....)" -> "/*assert(....)*/"
     """
     output_string = input_string
-    asserts = list(re.finditer(r"\bassert[ ]*\(", input_string))
+    asserts = list(RE_ASSERT.finditer(input_string))
     for assert_item in asserts:
         p_start, p_end = find_parentheses_group(input_string, assert_item.end() - 1)
         start = assert_item.start()
@@ -506,9 +512,7 @@ def replace_forceinline(input_string):
     replacing '__forceinline__' with 'inline' as a workaround
     https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_faq.md#what-if-hip-generates-error-of-symbol-multiply-defined-only-on-amd-machine
     """
-    output_string = input_string
-    output_string = re.sub("__forceinline__", "inline", output_string)
-    return output_string
+    return input_string.replace("__forceinline__", "inline")
 
 
 def replace_math_functions(input_string):
@@ -523,6 +527,9 @@ def replace_math_functions(input_string):
     return output_string
 
 
+RE_SYNCTHREADS = re.compile(r"[:]?[:]?\b(__syncthreads)\b(\w*\()")
+
+
 def hip_header_magic(input_string):
     """If the file makes kernel builtin calls and does not include the cuda_runtime.h header,
     then automatically add an #include to match the "magic" includes provided by NVCC.
@@ -542,7 +549,7 @@ def hip_header_magic(input_string):
     hasDeviceLogic = "hipLaunchKernelGGL" in output_string
     hasDeviceLogic += "__global__" in output_string
     hasDeviceLogic += "__shared__" in output_string
-    hasDeviceLogic += re.search(r"[:]?[:]?\b(__syncthreads)\b(\w*\()", output_string) is not None
+    hasDeviceLogic += RE_SYNCTHREADS.search(output_string) is not None
 
     # If device logic found, provide the necessary header.
     if hasDeviceLogic:
@@ -551,6 +558,9 @@ def hip_header_magic(input_string):
     return output_string
 
 
+RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;")
+
+
 def replace_extern_shared(input_string):
     """Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
        https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_kernel_language.md#__shared__
@@ -559,8 +569,7 @@ def replace_extern_shared(input_string):
         "extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
     """
     output_string = input_string
-    output_string = re.sub(
-        r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;",
+    output_string = RE_EXTERN_SHARED.sub(
         lambda inp: "HIP_DYNAMIC_SHARED({0} {1}, {2})".format(
             inp.group(1) or "", inp.group(2), inp.group(3)), output_string)
 
@@ -790,6 +799,84 @@ def is_caffe2_gpu_file(filepath):
     return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)
 
 
+# Cribbed from https://stackoverflow.com/questions/42742810/speed-up-millions-of-regex-replacements-in-python-3/42789508#42789508
+class Trie():
+    """Regex::Trie in Python. Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
+    The corresponding Regex should match much faster than a simple Regex union."""
+
+    def __init__(self):
+        self.data = {}
+
+    def add(self, word):
+        ref = self.data
+        for char in word:
+            ref[char] = char in ref and ref[char] or {}
+            ref = ref[char]
+        ref[''] = 1
+
+    def dump(self):
+        return self.data
+
+    def quote(self, char):
+        return re.escape(char)
+
+    def _pattern(self, pData):
+        data = pData
+        if "" in data and len(data.keys()) == 1:
+            return None
+
+        alt = []
+        cc = []
+        q = 0
+        for char in sorted(data.keys()):
+            if isinstance(data[char], dict):
+                try:
+                    recurse = self._pattern(data[char])
+                    alt.append(self.quote(char) + recurse)
+                except:
+                    cc.append(self.quote(char))
+            else:
+                q = 1
+        cconly = not len(alt) > 0
+
+        if len(cc) > 0:
+            if len(cc) == 1:
+                alt.append(cc[0])
+            else:
+                alt.append('[' + ''.join(cc) + ']')
+
+        if len(alt) == 1:
+            result = alt[0]
+        else:
+            result = "(?:" + "|".join(alt) + ")"
+
+        if q:
+            if cconly:
+                result += "?"
+            else:
+                result = "(?:%s)?" % result
+        return result
+
+    def pattern(self):
+        return self._pattern(self.dump())
+
+
+CAFFE2_TRIE = Trie()
+CAFFE2_MAP = {}
+PYTORCH_TRIE = Trie()
+PYTORCH_MAP = {}
+for mapping in CUDA_TO_HIP_MAPPINGS:
+    for src, value in mapping.items():
+        dst = value[0]
+        meta_data = value[1:]
+        if constants.API_CAFFE2 not in meta_data:
+            PYTORCH_TRIE.add(src)
+            PYTORCH_MAP[src] = dst
+        CAFFE2_TRIE.add(src)
+        CAFFE2_MAP[src] = dst
+RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.pattern())
+RE_PYTORCH_PREPROCESSOR = re.compile(r'\b{0}\b'.format(PYTORCH_TRIE.pattern()))
+
 def preprocessor(output_directory, filepath, stats):
     """ Executes the CUDA -> HIP conversion on the specified file. """
     fin_path = os.path.join(output_directory, filepath)
@@ -801,27 +888,15 @@ def preprocessor(output_directory, filepath, stats):
         os.makedirs(os.path.dirname(fout_path))
 
     with open(fout_path, 'w') as fout:
-        # Perform type, method, constant replacements
-        for mapping in CUDA_TO_HIP_MAPPINGS:
-            for cuda_type, value in mapping.items():
-                # Extract relevant information
-                hip_type = value[0]
-                meta_data = value[1:]
-
-                if constants.API_CAFFE2 in meta_data and is_pytorch_file(filepath):
-                    continue
-
-                if output_source.find(cuda_type) > -1:
-                    # Check if supported
-                    if constants.HIP_UNSUPPORTED in meta_data:
-                        stats["unsupported_calls"].append((cuda_type, filepath))
-
-                if cuda_type in output_source:
-                    if is_pytorch_file(filepath):
-                        pattern = r'(\b{0}\b)'.format(re.escape(cuda_type))
-                    else:
-                        pattern = r'({0})'.format(re.escape(cuda_type))
-                    output_source = re.sub(pattern, hip_type, output_source)
+        # unsupported_calls statistics reporting is broken atm
+        if is_pytorch_file(filepath):
+            def pt_repl(m):
+                return PYTORCH_MAP[m.group(0)]
+            output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
+        else:
+            def c2_repl(m):
+                return CAFFE2_MAP[m.group(0)]
+            output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source)
 
         # Perform Kernel Launch Replacements
         output_source = processKernelLaunches(output_source, stats)
@@ -831,7 +906,7 @@ def preprocessor(output_directory, filepath, stats):
         #    output_source = disable_asserts(output_source)
 
         # Replace std:: with non-std:: versions
-        if re.search(r"\.cu$", filepath) or re.search(r"\.cuh$", filepath):
+        if filepath.endswith(".cu") or filepath.endswith(".cuh"):
           output_source = replace_math_functions(output_source)
 
         # Replace __forceinline__ with inline
@@ -894,6 +969,10 @@ def fix_static_global_kernels(in_txt):
 # get deleted soon.
 
 
+RE_KERNEL_TEMPLATE = re.compile(r"(template[ ]*<(.*)>\n.*\n?)?__global__ void[\n| ](\w+(\(.*\))?)\(")
+RE_GENERATE_KERNEL = re.compile(r"GENERATE_KERNEL([1-9])\((.*)\)")
+
+
 def get_kernel_template_params(output_directory, the_file, KernelTemplateParams, template_param_to_value):
     """Scan for __global__ kernel definitions then extract its argument types, and static cast as necessary"""
     # Read the kernel file.
@@ -901,8 +980,7 @@ def get_kernel_template_params(output_directory, the_file, KernelTemplateParams,
         # Extract all kernels with their templates inside of the file
         string = f.read()
 
-        get_kernel_definitions = [k for k in re.finditer(
-            r"(template[ ]*<(.*)>\n.*\n?)?__global__ void[\n| ](\w+(\(.*\))?)\(", string)]
+        get_kernel_definitions = [k for k in RE_KERNEL_TEMPLATE.finditer(string)]
 
         # Create new launch syntax
         for kernel in get_kernel_definitions:
@@ -967,7 +1045,7 @@ def get_kernel_template_params(output_directory, the_file, KernelTemplateParams,
 
         # Extract generated kernels
         # curandStateMtgp32 *state, int size, T *result, ARG1
-        for kernel in re.finditer(r"GENERATE_KERNEL([1-9])\((.*)\)", string):
+        for kernel in RE_GENERATE_KERNEL.finditer(string):
             kernel_gen_type = int(kernel.group(1))
             kernel_name = kernel.group(2).split(",")[0]
             kernel_params = kernel.group(2).split(",")[1:]
@@ -1020,11 +1098,14 @@ def disable_unsupported_function_call(function, input_string, replacement):
     return output_string
 
 
+RE_INCLUDE = re.compile(r"#include .*\n")
+
+
 def disable_module(input_file):
     """Disable a module entirely except for header includes."""
     with openf(input_file, "r+") as f:
         txt = f.read()
-        last = list(re.finditer(r"#include .*\n", txt))[-1]
+        last = list(RE_INCLUDE.finditer(txt))[-1]
         end = last.end()
 
         disabled = "{0}#if !defined(__HIP_PLATFORM_HCC__)\n{1}\n#endif".format(txt[0:end], txt[end:])
@@ -1080,6 +1161,9 @@ def extract_arguments(start, string):
     return arguments
 
 
+RE_HIP_LAUNCH_KERNEL_GGL = re.compile("hipLaunchKernelGGL\(")
+
+
 # Add static_cast to ensure that the type of kernel arguments matches that in the corresponding kernel definition
 def add_static_casts(orig_filepath, filepath, KernelTemplateParams):
     """Add static casts to kernel launches in order to keep launch argument types and kernel definition types matching.
@@ -1101,7 +1185,7 @@ def add_static_casts(orig_filepath, filepath, KernelTemplateParams):
     with openf(filepath, "r+") as fileobj:
         input_source = fileobj.read()
         new_output_source = input_source
-        for kernel in re.finditer("hipLaunchKernelGGL\(", input_source):
+        for kernel in RE_HIP_LAUNCH_KERNEL_GGL.finditer(input_source):
             arguments = extract_arguments(kernel.end() - 1, input_source)
 
             # Check if we have templating + static_cast information