])
-def device_guard(option, formals, is_factory_method, dispatch_options, dispatch_tensor):
+def device_guard(option, formals, dispatch_options, dispatch_tensor):
# For factory methods the `DeviceGuard` is already in the template.
if option.get('device_guard', True):
if dispatch_options:
return 'const DeviceGuard device_guard({}.device());'.format(dispatch_options['name'])
- if not is_factory_method and dispatch_tensor:
+ if dispatch_tensor:
return 'const OptionalDeviceGuard device_guard(device_of({}));'.format(dispatch_tensor)
return '// DeviceGuard omitted'
option['method_prefix_derived'] = '' if broadcast_arg is None else 's_'
if option['mode'] == 'TH':
option['device_guard'] = False
- option['device_guard_declaration'] = device_guard(option, formals, False, False, dispatch_tensor)
+ option['device_guard_declaration'] = device_guard(option, formals, False, dispatch_tensor)
env = nested_dict(option, top_env)
option['name'], ", ".join(option['method_formals_with_defaults']))
type_method_dispatch = option['type_method_definition_dispatch']
- backend_dispatch = isinstance(type_method_dispatch, dict)
- # We only dispatch via options if there is backend-specific dispatch
- # (otherwise it's a factory function that can dispatch directly to the
- # native function).
- dispatch_options = (find_formal('TensorOptions', formals)
- if backend_dispatch else None)
+ dispatch_options = find_formal('TensorOptions', formals)
# Only dispatch via tensor if there is no Options argument
dispatch_tensor = None if dispatch_options else find_dispatch_tensor(formals)
check_methods_do_not_start_with_underscore(option['name'], is_method)
option['method_prefix_derived'] = ''
- option['device_guard_declaration'] = device_guard(option, formals, is_factory_method,
- dispatch_options, dispatch_tensor)
+ option['device_guard_declaration'] = device_guard(option, formals, dispatch_options, dispatch_tensor)
env = nested_dict(option, top_env)
raise Exception("broadcasting is not yet supported for native functions, "
"but specified for function {}", option['name'])
- # Factory methods are not dispatched over `Type`.
- if not is_factory_method:
- if option['extended_method']:
- top_env['pure_virtual_extended_type_method_declarations'].append(
- PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
- else:
- top_env['pure_virtual_type_method_declarations'].append(
- PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
- top_env['type_method_declarations'].append(TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
+ if option['extended_method']:
+ top_env['pure_virtual_extended_type_method_declarations'].append(
+ PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
+ else:
+ top_env['pure_virtual_type_method_declarations'].append(
+ PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
+ top_env['type_method_declarations'].append(TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
option['native_type_method_dispatch'] = type_method_dispatch
# Note [Abstract ATen methods]
abstract = True
top_env['type_method_definitions'].append(
TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env))
- elif not is_factory_method:
+ else:
body = TYPE_DEFINITION_BODY_NATIVE.substitute(env)
top_env['type_method_definitions'].append(
TYPE_METHOD_DEFINITION_CONCRETE.substitute(
option['inferred_type'] = 'at::getNonVariableType(at::Backend::Undefined, at::ScalarType::Float)'
declaration = DEPRECATED_FUNCTION_DECLARATION if option['deprecated'] else FUNCTION_DECLARATION
top_env['function_declarations'].append(declaration.substitute(env))
- if is_factory_method:
- top_env['function_definitions'].append(FACTORY_DEFINITION.substitute(env))
- else:
- top_env['function_definitions'].append(FUNCTION_DEFINITION.substitute(env))
+ top_env['function_definitions'].append(FUNCTION_DEFINITION.substitute(env))
method_of.append('namespace')
output_options.append(OutputDeclaration(