from tools.codegen.api.translate import translate
from tools.codegen.selective_build.selector import SelectiveBuilder
+
+def gen_create_out_helper(backend_index: BackendIndex) -> List[str]:
+ if backend_index.dispatch_key == DispatchKey.Meta:
+ # TODO: dedupe this with the non-Meta (empty/empty_strided) branch below
+ core = """
+if (strides.empty()) {
+ return at::empty(sizes, options.device(at::kMeta));
+} else {
+ return at::empty_strided(sizes, strides, options.device(at::kMeta));
+}
+"""
+ else:
+ expanded_topts = "optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), " \
+ "options.device_opt(), options.pinned_memory_opt()"
+ empty_init = ""
+ if backend_index.dispatch_key == DispatchKey.CPU:
+ empty_impl = "at::native::empty_cpu"
+ empty_strided_impl = "at::native::empty_strided_cpu"
+ elif backend_index.dispatch_key == DispatchKey.CUDA:
+ empty_init = "globalContext().lazyInitCUDA();"
+ empty_impl = "at::native::empty_cuda"
+ empty_strided_impl = "at::native::empty_strided_cuda"
+ elif backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
+ empty_impl = "at::empty"
+ empty_strided_impl = "at::empty_strided"
+ else:
+ return []
+ core = f"""
+ {empty_init}
+ if (strides.empty()) {{
+ return {empty_impl}(sizes, {expanded_topts}, options.memory_format_opt());
+ }} else {{
+ // TODO: assert options.memory_format_opt() is nullopt (debug only?)
+ return {empty_strided_impl}(sizes, strides, {expanded_topts});
+ }}
+"""
+ return [f"""
+Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
+{core}
+}}
+"""]
+
+
+def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]:
+ return ["""
+void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
+ TORCH_CHECK(options.dtype() == out.dtype(),
+ "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
+ TORCH_CHECK(options.device() == out.device(),
+ "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
+ const bool resized = at::native::resize_output(out, sizes);
+ // Only restride if a resize occurred; otherwise we ignore the (advisory)
+ // strides from the meta function and directly use the output tensor's
+ // preexisting strides
+ if (resized) {
+ if (!strides.empty()) {
+ TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
+ at::native::as_strided_(out, sizes, strides);
+ } else if (options.memory_format_opt().has_value()) {
+ out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
+ }
+ }
+}
+"""]
+
+
+def gen_registration_helpers(backend_index: BackendIndex) -> List[str]:
+ return [
+ *gen_create_out_helper(backend_index),
+ *gen_resize_out_helper(backend_index)
+ ]
+
+
# Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp).
#
# - The primary function of this file is to register all of the
maybe_set_guard_line = maybe_set_guard = ''
if k is SchemaKind.functional:
- if self.backend_index.dispatch_key == DispatchKey.Meta:
- # TODO: dedupe this with below
- return """
-if (strides.empty()) {
- outputs_[output_idx] = at::empty(sizes, options.device(at::kMeta));
-} else {
- outputs_[output_idx] = at::empty_strided(sizes, strides, options.device(at::kMeta));
-}
-"""
- else:
- expanded_topts = "optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), " \
- "options.device_opt(), options.pinned_memory_opt()"
- empty_init = ""
- if self.backend_index.dispatch_key == DispatchKey.CPU:
- empty_impl = "at::native::empty_cpu"
- empty_strided_impl = "at::native::empty_strided_cpu"
- elif self.backend_index.dispatch_key == DispatchKey.CUDA:
- empty_init = "globalContext().lazyInitCUDA();"
- empty_impl = "at::native::empty_cuda"
- empty_strided_impl = "at::native::empty_strided_cuda"
- elif self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
- empty_impl = "at::empty"
- empty_strided_impl = "at::empty_strided"
- else:
- raise AssertionError("unsupported dispatch key")
- return f"""{maybe_set_guard_line}
-{empty_init}
-if (strides.empty()) {{
- outputs_[output_idx] = {empty_impl}(sizes, {expanded_topts}, options.memory_format_opt());
-}} else {{
- // TODO: assert options.memory_format_opt() is nullopt (debug only?)
- outputs_[output_idx] = {empty_strided_impl}(sizes, strides, {expanded_topts});
-}}
-"""
+ assert self.backend_index.dispatch_key in (
+ DispatchKey.Meta, DispatchKey.CPU, DispatchKey.CUDA,
+ DispatchKey.CompositeExplicitAutograd)
+ return f"""{maybe_set_guard_line}
+outputs_[output_idx] = create_out(sizes, strides, options);"""
elif k is SchemaKind.inplace:
return maybe_set_guard
elif k is SchemaKind.out:
return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
-TORCH_CHECK(options.dtype() == out.dtype(),
- "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
-TORCH_CHECK(options.device() == out.device(),
- "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
-bool resized = at::native::resize_output(outputs_[output_idx], sizes);
-// Only restride if a resize occurred; otherwise we ignore the (advisory)
-// strides from the meta function and directly use the output tensor's
-// preexisting strides
-if (resized) {{
- if (!strides.empty()) {{
- TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
- at::native::as_strided_(outputs_[output_idx], sizes, strides);
- }} else if (options.memory_format_opt().has_value()) {{
- outputs_[output_idx].get().unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
- }}
-}}
-"""
+resize_out(out, sizes, strides, options);"""
else:
assert_never(k)