${type_definition_body}
}
""")
-DEPRECATED_TYPE_METHOD_DEFINITION_CONCRETE = CodeTemplate("""\
-${return_type} TypeDefault::${api_name}(${type_method_formals}) const {
- ${device_guard_declaration}
- return at::native::${api_name}(${type_method_actuals}, options());
-}
-""")
# 4. add override to TypeDerived.h
TYPE_DERIVED_DECLARATION = CodeTemplate("""\
${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const override;
}
""")
-# special method definition for *deprecated* factory functions in Functions.h
-DEPRECATED_FACTORY_DEFINITION = CodeTemplate("""\
-static inline ${return_type} ${api_name}(${formals}) {
- return at::${api_name}(${type_method_actuals}, ${inferred_type}.options());
-}
-""")
-
# We need to cast to the base type because C++ may hide the base class
# implementation of ${api_name} if we have overloaded a function with
# the same name (but different signature) already
is_namespace_function = 'function' in option['variants']
is_factory_method = find_formal('TensorOptions', formals) and \
not dispatch_options and 'method' not in option['variants']
- is_deprecated_factory_method = len(formals) > 0 and \
- formals[0]['dynamic_type'] == 'Type' and \
- option['return_type'] == 'Tensor' and option['deprecated']
- needs_native_definition = not is_deprecated_factory_method
check_methods_do_not_start_with_underscore(option['name'], is_method)
abstract = True
top_env['type_method_definitions'].append(
TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env))
- elif is_deprecated_factory_method:
- top_env['type_method_definitions'].append(
- DEPRECATED_TYPE_METHOD_DEFINITION_CONCRETE.substitute(env))
elif not is_factory_method:
body = TYPE_DEFINITION_BODY_NATIVE.substitute(env)
top_env['type_method_definitions'].append(
env, type_definition_body=body))
# generate the at::native function declarations (i.e. what the user will implement)
- if needs_native_definition:
- if isinstance(type_method_dispatch, dict):
- generated_native_functions = [] # type: List[str]
- for key in sorted(type_method_dispatch.keys()):
- value = type_method_dispatch[key]
- if value not in generated_native_functions:
- option['native_type_method_dispatch'] = value
- top_env['native_function_declarations'].append(
- NATIVE_DECLARATION.substitute(env))
- generated_native_functions.append(value)
- else:
- top_env['native_function_declarations'].append(
- NATIVE_DECLARATION.substitute(env))
+ if isinstance(type_method_dispatch, dict):
+ generated_native_functions = [] # type: List[str]
+ for key in sorted(type_method_dispatch.keys()):
+ value = type_method_dispatch[key]
+ if value not in generated_native_functions:
+ option['native_type_method_dispatch'] = value
+ top_env['native_function_declarations'].append(
+ NATIVE_DECLARATION.substitute(env))
+ generated_native_functions.append(value)
+ else:
+ top_env['native_function_declarations'].append(
+ NATIVE_DECLARATION.substitute(env))
method_of = ['Type']
if is_method:
top_env['function_declarations'].append(declaration.substitute(env))
if is_factory_method:
top_env['function_definitions'].append(FACTORY_DEFINITION.substitute(env))
- elif is_deprecated_factory_method:
- top_env['function_definitions'].append(DEPRECATED_FACTORY_DEFINITION.substitute(env))
else:
top_env['function_definitions'].append(FUNCTION_DEFINITION.substitute(env))
method_of.append('namespace')