From a9983ac09c9772fa426224b080090083b4c9607b Mon Sep 17 00:00:00 2001
From: Peter Bell
Date: Fri, 27 Aug 2021 09:37:10 -0700
Subject: [PATCH] Refactor structured set_output in Register{DispatchKey}.cpp (#62188)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62188

These parts of the `set_output` code are identical for all operators in the
kernel registration files. So, this moves them from being copied into every
class to two helper functions at the top of the file.

Test Plan: Imported from OSS

Reviewed By: soulitzer

Differential Revision: D29962045

Pulled By: albanD

fbshipit-source-id: 753b8aac755f3c91b77ffa2c30a89ac91a84b7c4
---
 aten/src/ATen/templates/RegisterDispatchKey.cpp |   2 +
 tools/codegen/dest/__init__.py                  |   1 +
 tools/codegen/dest/register_dispatch_key.py     | 130 ++++++++++++++----
 tools/codegen/gen.py                            |   1 +
 tools/codegen/gen_backend_stubs.py              |   1 +
 5 files changed, 84 insertions(+), 51 deletions(-)

diff --git a/aten/src/ATen/templates/RegisterDispatchKey.cpp b/aten/src/ATen/templates/RegisterDispatchKey.cpp
index 1abc3ee..16caf53 100644
--- a/aten/src/ATen/templates/RegisterDispatchKey.cpp
+++ b/aten/src/ATen/templates/RegisterDispatchKey.cpp
@@ -43,6 +43,8 @@ namespace at {
 // at namespace already.
 namespace {
 
+${dispatch_helpers}
+
 ${dispatch_anonymous_definitions}
 
 TORCH_LIBRARY_IMPL(aten, ${DispatchKey}, m) {
diff --git a/tools/codegen/dest/__init__.py b/tools/codegen/dest/__init__.py
index ab4bada..441e442 100644
--- a/tools/codegen/dest/__init__.py
+++ b/tools/codegen/dest/__init__.py
@@ -1,2 +1,3 @@
 from .register_dispatch_key import RegisterDispatchKey as RegisterDispatchKey
+from .register_dispatch_key import gen_registration_helpers as gen_registration_helpers
 from .native_functions import compute_native_function_declaration as compute_native_function_declaration
diff --git a/tools/codegen/dest/register_dispatch_key.py b/tools/codegen/dest/register_dispatch_key.py
index a943f51..784ee56 100644
--- a/tools/codegen/dest/register_dispatch_key.py
+++ b/tools/codegen/dest/register_dispatch_key.py
@@ -23,6 +23,79 @@ import tools.codegen.api.structured as structured
 from tools.codegen.api.translate import translate
 from tools.codegen.selective_build.selector import SelectiveBuilder
 
+
+def gen_create_out_helper(backend_index: BackendIndex) -> List[str]:
+    if backend_index.dispatch_key == DispatchKey.Meta:
+        # TODO: dedupe this with below
+        core = """
+if (strides.empty()) {
+  return at::empty(sizes, options.device(at::kMeta));
+} else {
+  return at::empty_strided(sizes, strides, options.device(at::kMeta));
+}
+"""
+    else:
+        expanded_topts = "optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), " \
+                         "options.device_opt(), options.pinned_memory_opt()"
+        empty_init = ""
+        if backend_index.dispatch_key == DispatchKey.CPU:
+            empty_impl = "at::native::empty_cpu"
+            empty_strided_impl = "at::native::empty_strided_cpu"
+        elif backend_index.dispatch_key == DispatchKey.CUDA:
+            empty_init = "globalContext().lazyInitCUDA();"
+            empty_impl = "at::native::empty_cuda"
+            empty_strided_impl = "at::native::empty_strided_cuda"
+        elif backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
+            empty_impl = "at::empty"
+            empty_strided_impl = "at::empty_strided"
+        else:
+            return []
+        core = f"""
+  {empty_init}
+  if (strides.empty()) {{
+    return {empty_impl}(sizes, {expanded_topts}, options.memory_format_opt());
+  }} else {{
+    // TODO: assert options.memory_format_opt() is nullopt (debug only?)
+    return {empty_strided_impl}(sizes, strides, {expanded_topts});
+  }}
+"""
+    return [f"""
+Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
+{core}
+}}
+"""]
+
+
+def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]:
+    return ["""
+void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
+  TORCH_CHECK(options.dtype() == out.dtype(),
+      "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
+  TORCH_CHECK(options.device() == out.device(),
+      "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
+  const bool resized = at::native::resize_output(out, sizes);
+  // Only restride if a resize occurred; otherwise we ignore the (advisory)
+  // strides from the meta function and directly use the output tensor's
+  // preexisting strides
+  if (resized) {
+    if (!strides.empty()) {
+      TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
+      at::native::as_strided_(out, sizes, strides);
+    } else if (options.memory_format_opt().has_value()) {
+      out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
+    }
+  }
+}
+"""]
+
+
+def gen_registration_helpers(backend_index: BackendIndex) -> List[str]:
+    return [
+        *gen_create_out_helper(backend_index),
+        *gen_resize_out_helper(backend_index)
+    ]
+
+
 # Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp).
 #
 # - The primary function of this file is to register all of the
@@ -344,62 +417,17 @@ if (C10_UNLIKELY(current_device.has_value())) {
             maybe_set_guard_line = maybe_set_guard = ''
 
         if k is SchemaKind.functional:
-            if self.backend_index.dispatch_key == DispatchKey.Meta:
-                # TODO: dedupe this with below
-                return """
-if (strides.empty()) {
-  outputs_[output_idx] = at::empty(sizes, options.device(at::kMeta));
-} else {
-  outputs_[output_idx] = at::empty_strided(sizes, strides, options.device(at::kMeta));
-}
-"""
-            else:
-                expanded_topts = "optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), " \
-                                 "options.device_opt(), options.pinned_memory_opt()"
-                empty_init = ""
-                if self.backend_index.dispatch_key == DispatchKey.CPU:
-                    empty_impl = "at::native::empty_cpu"
-                    empty_strided_impl = "at::native::empty_strided_cpu"
-                elif self.backend_index.dispatch_key == DispatchKey.CUDA:
-                    empty_init = "globalContext().lazyInitCUDA();"
-                    empty_impl = "at::native::empty_cuda"
-                    empty_strided_impl = "at::native::empty_strided_cuda"
-                elif self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
-                    empty_impl = "at::empty"
-                    empty_strided_impl = "at::empty_strided"
-                else:
-                    raise AssertionError("unsupported dispatch key")
-                return f"""{maybe_set_guard_line}
-{empty_init}
-if (strides.empty()) {{
-  outputs_[output_idx] = {empty_impl}(sizes, {expanded_topts}, options.memory_format_opt());
-}} else {{
-  // TODO: assert options.memory_format_opt() is nullopt (debug only?)
-  outputs_[output_idx] = {empty_strided_impl}(sizes, strides, {expanded_topts});
-}}
-"""
+            assert self.backend_index.dispatch_key in (
+                DispatchKey.Meta, DispatchKey.CPU, DispatchKey.CUDA,
+                DispatchKey.CompositeExplicitAutograd)
+            return f"""{maybe_set_guard_line}
+outputs_[output_idx] = create_out(sizes, strides, options);"""
         elif k is SchemaKind.inplace:
             return maybe_set_guard
         elif k is SchemaKind.out:
             return f"""{maybe_set_guard_line}
 const auto& out = outputs_[output_idx].get();
-TORCH_CHECK(options.dtype() == out.dtype(),
-    "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
-TORCH_CHECK(options.device() == out.device(),
-    "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
-bool resized = at::native::resize_output(outputs_[output_idx], sizes);
-// Only restride if a resize occurred; otherwise we ignore the (advisory)
-// strides from the meta function and directly use the output tensor's
-// preexisting strides
-if (resized) {{
-  if (!strides.empty()) {{
-    TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
-    at::native::as_strided_(outputs_[output_idx], sizes, strides);
-  }} else if (options.memory_format_opt().has_value()) {{
-    outputs_[output_idx].get().unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
-  }}
-}}
-"""
+resize_out(out, sizes, strides, options);"""
         else:
             assert_never(k)
 
diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py
index ffa4ed7..203b5a9 100644
--- a/tools/codegen/gen.py
+++ b/tools/codegen/gen.py
@@ -1100,6 +1100,7 @@ def main() -> None:
             'namespaced_headers': f'#include <ATen/{dispatch_key}Functions.h>' if dispatch_key in functions_keys else '',
             'DispatchKey': dispatch_key,
             'dispatch_namespace': dispatch_key.lower(),
+            'dispatch_helpers': dest.gen_registration_helpers(backend_indices[dispatch_key]),
             'dispatch_namespaced_definitions': list(concatMap(
                 dest.RegisterDispatchKey(
                     backend_indices[dispatch_key],
diff --git a/tools/codegen/gen_backend_stubs.py b/tools/codegen/gen_backend_stubs.py
index 51f81c7..5fad11c 100644
--- a/tools/codegen/gen_backend_stubs.py
+++ b/tools/codegen/gen_backend_stubs.py
@@ -231,6 +231,7 @@ def run(source_yaml: str, output_dir: str, dry_run: bool, impl_path: Optional[st
         'namespaced_headers': '',
         'DispatchKey': dispatch_key,
         'dispatch_namespace': dispatch_key.lower(),
+        'dispatch_helpers': dest.gen_registration_helpers(backend_indices[dispatch_key]),
         'dispatch_namespaced_definitions': list(concatMap(
             dest.RegisterDispatchKey(
                 backend_indices[dispatch_key],
-- 
2.7.4
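
[Reviewer note, not part of the patch] For context, this is roughly the C++ that
gen_create_out_helper expands to for the CPU key, obtained by mechanically
substituting the f-string parameters above (empty_init = "", empty_impl =
at::native::empty_cpu, empty_strided_impl = at::native::empty_strided_cpu,
expanded_topts spelled out). It is a fragment of the generated RegisterCPU.cpp
shown only to illustrate the substitution, not standalone code, and line
wrapping here differs from the emitted output:

// Emitted once per Register{DispatchKey}.cpp via ${dispatch_helpers}
Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
  if (strides.empty()) {
    // Contiguous (or explicitly requested memory-format) allocation
    return at::native::empty_cpu(sizes, optTypeMetaToScalarType(options.dtype_opt()),
        options.layout_opt(), options.device_opt(), options.pinned_memory_opt(),
        options.memory_format_opt());
  } else {
    // TODO: assert options.memory_format_opt() is nullopt (debug only?)
    return at::native::empty_strided_cpu(sizes, strides,
        optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(),
        options.device_opt(), options.pinned_memory_opt());
  }
}

resize_out is emitted verbatim from the string in gen_resize_out_helper, and each
generated structured class's set_output body shrinks to a single call:
  outputs_[output_idx] = create_out(sizes, strides, options);   // functional variant
  resize_out(out, sizes, strides, options);                     // out= variant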