Dispatch factory functions on Type (#15093)
authorRoy Li <royboy@fb.com>
Fri, 1 Feb 2019 18:54:59 +0000 (10:54 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 1 Feb 2019 19:00:15 +0000 (11:00 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15093

Needed for backend extensions.

Reviewed By: ezyang

Differential Revision: D13427897

fbshipit-source-id: d0b34b0072e597ae599bd3bc25356831d7a18d6a

aten/src/ATen/function_wrapper.py

index 7087bf5..4978f8e 100644 (file)
@@ -547,12 +547,12 @@ OutputDeclaration = NamedTuple('OutputDeclaration', [
 ])
 
 
-def device_guard(option, formals, is_factory_method, dispatch_options, dispatch_tensor):
+def device_guard(option, formals, dispatch_options, dispatch_tensor):
     # For factory methods the `DeviceGuard` is already in the template.
     if option.get('device_guard', True):
         if dispatch_options:
             return 'const DeviceGuard device_guard({}.device());'.format(dispatch_options['name'])
-        if not is_factory_method and dispatch_tensor:
+        if dispatch_tensor:
             return 'const OptionalDeviceGuard device_guard(device_of({}));'.format(dispatch_tensor)
     return '// DeviceGuard omitted'
 
@@ -836,7 +836,7 @@ def create_generic(top_env, declarations):
         option['method_prefix_derived'] = '' if broadcast_arg is None else 's_'
         if option['mode'] == 'TH':
             option['device_guard'] = False
-        option['device_guard_declaration'] = device_guard(option, formals, False, False, dispatch_tensor)
+        option['device_guard_declaration'] = device_guard(option, formals, False, dispatch_tensor)
 
         env = nested_dict(option, top_env)
 
@@ -1057,13 +1057,8 @@ def create_generic(top_env, declarations):
                 option['name'], ", ".join(option['method_formals_with_defaults']))
 
         type_method_dispatch = option['type_method_definition_dispatch']
-        backend_dispatch = isinstance(type_method_dispatch, dict)
 
-        # We only dispatch via options if there is backend-specific dispatch
-        # (otherwise it's a factory function that can dispatch directly to the
-        # native function).
-        dispatch_options = (find_formal('TensorOptions', formals)
-                            if backend_dispatch else None)
+        dispatch_options = find_formal('TensorOptions', formals)
         # Only dispatch via tensor if there is no Options argument
         dispatch_tensor = None if dispatch_options else find_dispatch_tensor(formals)
 
@@ -1081,8 +1076,7 @@ def create_generic(top_env, declarations):
         check_methods_do_not_start_with_underscore(option['name'], is_method)
 
         option['method_prefix_derived'] = ''
-        option['device_guard_declaration'] = device_guard(option, formals, is_factory_method,
-                                                          dispatch_options, dispatch_tensor)
+        option['device_guard_declaration'] = device_guard(option, formals, dispatch_options, dispatch_tensor)
 
         env = nested_dict(option, top_env)
 
@@ -1091,15 +1085,13 @@ def create_generic(top_env, declarations):
             raise Exception("broadcasting is not yet supported for native functions, "
                             "but specified for function {}", option['name'])
 
-        # Factory methods are not dispatched over `Type`.
-        if not is_factory_method:
-            if option['extended_method']:
-                top_env['pure_virtual_extended_type_method_declarations'].append(
-                    PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
-            else:
-                top_env['pure_virtual_type_method_declarations'].append(
-                    PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
-            top_env['type_method_declarations'].append(TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
+        if option['extended_method']:
+            top_env['pure_virtual_extended_type_method_declarations'].append(
+                PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
+        else:
+            top_env['pure_virtual_type_method_declarations'].append(
+                PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
+        top_env['type_method_declarations'].append(TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
         option['native_type_method_dispatch'] = type_method_dispatch
 
         # Note [Abstract ATen methods]
@@ -1115,7 +1107,7 @@ def create_generic(top_env, declarations):
             abstract = True
             top_env['type_method_definitions'].append(
                 TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env))
-        elif not is_factory_method:
+        else:
             body = TYPE_DEFINITION_BODY_NATIVE.substitute(env)
             top_env['type_method_definitions'].append(
                 TYPE_METHOD_DEFINITION_CONCRETE.substitute(
@@ -1153,10 +1145,7 @@ def create_generic(top_env, declarations):
                 option['inferred_type'] = 'at::getNonVariableType(at::Backend::Undefined, at::ScalarType::Float)'
             declaration = DEPRECATED_FUNCTION_DECLARATION if option['deprecated'] else FUNCTION_DECLARATION
             top_env['function_declarations'].append(declaration.substitute(env))
-            if is_factory_method:
-                top_env['function_definitions'].append(FACTORY_DEFINITION.substitute(env))
-            else:
-                top_env['function_definitions'].append(FUNCTION_DEFINITION.substitute(env))
+            top_env['function_definitions'].append(FUNCTION_DEFINITION.substitute(env))
             method_of.append('namespace')
 
         output_options.append(OutputDeclaration(