Refactor structured set_output in Register{DispatchKey}.cpp (#62188)
author: Peter Bell <peterbell10@live.co.uk>
Fri, 27 Aug 2021 16:37:10 +0000 (09:37 -0700)
committer: Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 16:38:27 +0000 (09:38 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62188

These parts of the `set_output` code are identical for all operators in the
kernel registration files. So, this moves them from being copied into every
class to two helper functions at the top of the file.

Test Plan: Imported from OSS

Reviewed By: soulitzer

Differential Revision: D29962045

Pulled By: albanD

fbshipit-source-id: 753b8aac755f3c91b77ffa2c30a89ac91a84b7c4

aten/src/ATen/templates/RegisterDispatchKey.cpp
tools/codegen/dest/__init__.py
tools/codegen/dest/register_dispatch_key.py
tools/codegen/gen.py
tools/codegen/gen_backend_stubs.py

index 1abc3ee..16caf53 100644 (file)
@@ -43,6 +43,8 @@ namespace at {
 // at namespace already.
 namespace {
 
+${dispatch_helpers}
+
 ${dispatch_anonymous_definitions}
 
 TORCH_LIBRARY_IMPL(aten, ${DispatchKey}, m) {
index ab4bada..441e442 100644 (file)
@@ -1,2 +1,3 @@
 from .register_dispatch_key import RegisterDispatchKey as RegisterDispatchKey
+from .register_dispatch_key import gen_registration_helpers as gen_registration_helpers
 from .native_functions import compute_native_function_declaration as compute_native_function_declaration
index a943f51..784ee56 100644 (file)
@@ -23,6 +23,79 @@ import tools.codegen.api.structured as structured
 from tools.codegen.api.translate import translate
 from tools.codegen.selective_build.selector import SelectiveBuilder
 
+
+def gen_create_out_helper(backend_index: BackendIndex) -> List[str]:
+    if backend_index.dispatch_key == DispatchKey.Meta:
+        # TODO: dedupe this with below
+        core = """
+if (strides.empty()) {
+    return at::empty(sizes, options.device(at::kMeta));
+} else {
+    return at::empty_strided(sizes, strides, options.device(at::kMeta));
+}
+"""
+    else:
+        expanded_topts = "optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), " \
+            "options.device_opt(), options.pinned_memory_opt()"
+        empty_init = ""
+        if backend_index.dispatch_key == DispatchKey.CPU:
+            empty_impl = "at::native::empty_cpu"
+            empty_strided_impl = "at::native::empty_strided_cpu"
+        elif backend_index.dispatch_key == DispatchKey.CUDA:
+            empty_init = "globalContext().lazyInitCUDA();"
+            empty_impl = "at::native::empty_cuda"
+            empty_strided_impl = "at::native::empty_strided_cuda"
+        elif backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
+            empty_impl = "at::empty"
+            empty_strided_impl = "at::empty_strided"
+        else:
+            return []
+        core = f"""
+  {empty_init}
+  if (strides.empty()) {{
+      return {empty_impl}(sizes, {expanded_topts}, options.memory_format_opt());
+  }} else {{
+      // TODO: assert options.memory_format_opt() is nullopt (debug only?)
+      return {empty_strided_impl}(sizes, strides, {expanded_topts});
+  }}
+"""
+    return [f"""
+Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
+{core}
+}}
+"""]
+
+
+def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]:
+    return ["""
+void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
+  TORCH_CHECK(options.dtype() == out.dtype(),
+      "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
+  TORCH_CHECK(options.device() == out.device(),
+      "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
+  const bool resized = at::native::resize_output(out, sizes);
+  // Only restride if a resize occurred; otherwise we ignore the (advisory)
+  // strides from the meta function and directly use the output tensor's
+  // preexisting strides
+  if (resized) {
+    if (!strides.empty()) {
+      TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
+      at::native::as_strided_(out, sizes, strides);
+    } else if (options.memory_format_opt().has_value()) {
+      out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
+    }
+  }
+}
+"""]
+
+
+def gen_registration_helpers(backend_index: BackendIndex) -> List[str]:
+    return [
+        *gen_create_out_helper(backend_index),
+        *gen_resize_out_helper(backend_index)
+    ]
+
+
 # Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp).
 #
 #   - The primary function of this file is to register all of the
@@ -344,62 +417,17 @@ if (C10_UNLIKELY(current_device.has_value())) {
             maybe_set_guard_line = maybe_set_guard = ''
 
         if k is SchemaKind.functional:
-            if self.backend_index.dispatch_key == DispatchKey.Meta:
-                # TODO: dedupe this with below
-                return """
-if (strides.empty()) {
-    outputs_[output_idx] = at::empty(sizes, options.device(at::kMeta));
-} else {
-    outputs_[output_idx] = at::empty_strided(sizes, strides, options.device(at::kMeta));
-}
-"""
-            else:
-                expanded_topts = "optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), " \
-                    "options.device_opt(), options.pinned_memory_opt()"
-                empty_init = ""
-                if self.backend_index.dispatch_key == DispatchKey.CPU:
-                    empty_impl = "at::native::empty_cpu"
-                    empty_strided_impl = "at::native::empty_strided_cpu"
-                elif self.backend_index.dispatch_key == DispatchKey.CUDA:
-                    empty_init = "globalContext().lazyInitCUDA();"
-                    empty_impl = "at::native::empty_cuda"
-                    empty_strided_impl = "at::native::empty_strided_cuda"
-                elif self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
-                    empty_impl = "at::empty"
-                    empty_strided_impl = "at::empty_strided"
-                else:
-                    raise AssertionError("unsupported dispatch key")
-                return f"""{maybe_set_guard_line}
-{empty_init}
-if (strides.empty()) {{
-    outputs_[output_idx] = {empty_impl}(sizes, {expanded_topts}, options.memory_format_opt());
-}} else {{
-    // TODO: assert options.memory_format_opt() is nullopt (debug only?)
-    outputs_[output_idx] = {empty_strided_impl}(sizes, strides, {expanded_topts});
-}}
-"""
+            assert self.backend_index.dispatch_key in (
+                DispatchKey.Meta, DispatchKey.CPU, DispatchKey.CUDA,
+                DispatchKey.CompositeExplicitAutograd)
+            return f"""{maybe_set_guard_line}
+outputs_[output_idx] = create_out(sizes, strides, options);"""
         elif k is SchemaKind.inplace:
             return maybe_set_guard
         elif k is SchemaKind.out:
             return f"""{maybe_set_guard_line}
 const auto& out = outputs_[output_idx].get();
-TORCH_CHECK(options.dtype() == out.dtype(),
-    "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
-TORCH_CHECK(options.device() == out.device(),
-    "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
-bool resized = at::native::resize_output(outputs_[output_idx], sizes);
-// Only restride if a resize occurred; otherwise we ignore the (advisory)
-// strides from the meta function and directly use the output tensor's
-// preexisting strides
-if (resized) {{
-    if (!strides.empty()) {{
-        TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
-        at::native::as_strided_(outputs_[output_idx], sizes, strides);
-    }} else if (options.memory_format_opt().has_value()) {{
-        outputs_[output_idx].get().unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
-    }}
-}}
-"""
+resize_out(out, sizes, strides, options);"""
         else:
             assert_never(k)
 
index ffa4ed7..203b5a9 100644 (file)
@@ -1100,6 +1100,7 @@ def main() -> None:
             'namespaced_headers': f'#include <ATen/{dispatch_key}Functions.h>' if dispatch_key in functions_keys else '',
             'DispatchKey': dispatch_key,
             'dispatch_namespace': dispatch_key.lower(),
+            'dispatch_helpers': dest.gen_registration_helpers(backend_indices[dispatch_key]),
             'dispatch_namespaced_definitions': list(concatMap(
                 dest.RegisterDispatchKey(
                     backend_indices[dispatch_key],
index 51f81c7..5fad11c 100644 (file)
@@ -231,6 +231,7 @@ def run(source_yaml: str, output_dir: str, dry_run: bool, impl_path: Optional[st
                 'namespaced_headers': '',
                 'DispatchKey': dispatch_key,
                 'dispatch_namespace': dispatch_key.lower(),
+                'dispatch_helpers': dest.gen_registration_helpers(backend_indices[dispatch_key]),
                 'dispatch_namespaced_definitions': list(concatMap(
                     dest.RegisterDispatchKey(
                         backend_indices[dispatch_key],