Add optional tensor arguments to torch function handling (#63967)
author Samantha Andow <samdow@fb.com>
Tue, 31 Aug 2021 02:15:16 +0000 (19:15 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 02:21:26 +0000 (19:21 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/63435

Adds optional tensor arguments to the torch function handling checks. The only function in the functional file I didn't do this for was `multi_head_attention_forward`, since it already handled some optional tensor arguments but not others, so it seemed like the arguments there were chosen deliberately.
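To illustrate the effect of the change, here is a minimal sketch (not part of the commit; `LoggingTensor` is a hypothetical subclass used only for demonstration). Before this change, passing an overriding type only through an optional argument such as `weight` would bypass `__torch_function__`; after it, the functional dispatches to the override:

    import torch
    import torch.nn.functional as F

    class LoggingTensor(torch.Tensor):
        # Hypothetical subclass that records which functionals it intercepts.
        seen = []

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            cls.seen.append(func.__name__)
            return super().__torch_function__(func, types, args, kwargs or {})

    x = torch.randn(4, 3)                          # plain tensor input
    w = torch.ones(3).as_subclass(LoggingTensor)   # only the optional weight overrides

    out = F.layer_norm(x, (3,), weight=w)
    print(LoggingTensor.seen)                      # with this change, expected to include 'layer_norm'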

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63967

Reviewed By: albanD

Differential Revision: D30640441

Pulled By: ezyang

fbshipit-source-id: 5ef9554d2fb6c14779f8f45542ab435fb49e5d0f

torch/nn/functional.py

index c11e261..4b0449c 100644
@@ -442,10 +442,10 @@ def fractional_max_pool2d_with_indices(
     .. _Fractional MaxPooling:
         http://arxiv.org/abs/1412.6071
     """
-    if has_torch_function_unary(input):
+    if has_torch_function_variadic(input, _random_samples):
         return handle_torch_function(
             fractional_max_pool2d_with_indices,
-            (input,),
+            (input, _random_samples),
             input,
             kernel_size,
             output_size=output_size,
@@ -473,10 +473,10 @@ def _fractional_max_pool2d(
     return_indices: bool = False,
     _random_samples: Optional[Tensor] = None
 ) -> Tensor:
-    if has_torch_function_unary(input):
+    if has_torch_function_variadic(input, _random_samples):
         return handle_torch_function(
             fractional_max_pool2d,
-            (input,),
+            (input, _random_samples),
             input,
             kernel_size,
             output_size=output_size,
@@ -537,10 +537,10 @@ def fractional_max_pool3d_with_indices(
     .. _Fractional MaxPooling:
         http://arxiv.org/abs/1412.6071
     """
-    if has_torch_function_unary(input):
+    if has_torch_function_variadic(input, _random_samples):
         return handle_torch_function(
             fractional_max_pool3d_with_indices,
-            (input,),
+            (input, _random_samples),
             input,
             kernel_size,
             output_size=output_size,
@@ -571,10 +571,10 @@ def _fractional_max_pool3d(
     return_indices: bool = False,
     _random_samples: Optional[Tensor] = None
 ) -> Tensor:
-    if has_torch_function_unary(input):
+    if has_torch_function_variadic(input, _random_samples):
         return handle_torch_function(
             fractional_max_pool3d,
-            (input,),
+            (input, _random_samples),
             input,
             kernel_size,
             output_size=output_size,
@@ -1843,8 +1843,8 @@ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tens
         - Bias: :math:`(out\_features)`
         - Output: :math:`(N, *, out\_features)`
     """
-    if has_torch_function_variadic(input, weight):
-        return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
+    if has_torch_function_variadic(input, weight, bias):
+        return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
     return torch._C._nn.linear(input, weight, bias)
 
 
@@ -1865,10 +1865,10 @@ def bilinear(input1: Tensor, input2: Tensor, weight: Tensor, bias: Optional[Tens
         - output: :math:`(N, *, H_{out})` where :math:`H_{out}=\text{out\_features}`
           and all but the last dimension are the same shape as the input.
     """
-    if has_torch_function_variadic(input1, input2, weight):
+    if has_torch_function_variadic(input1, input2, weight, bias):
         return handle_torch_function(
             bilinear,
-            (input1, input2, weight),
+            (input1, input2, weight, bias),
             input1, input2, weight,
             bias=bias
         )
@@ -2135,10 +2135,10 @@ def embedding_bag(
         tensor([[ 0.0000,  0.0000,  0.0000],
                 [-0.7082,  3.2145, -2.6251]])
     """
-    if has_torch_function_variadic(input, weight):
+    if has_torch_function_variadic(input, weight, offsets, per_sample_weights):
         return handle_torch_function(
             embedding_bag,
-            (input, weight),
+            (input, weight, offsets, per_sample_weights),
             input,
             weight,
             offsets=offsets,
@@ -2263,10 +2263,10 @@ def batch_norm(
     See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`,
     :class:`~torch.nn.BatchNorm3d` for details.
     """
-    if has_torch_function_unary(input):
+    if has_torch_function_variadic(input, running_mean, running_var, weight, bias):
         return handle_torch_function(
             batch_norm,
-            (input,),
+            (input, running_mean, running_var, weight, bias),
             input,
             running_mean,
             running_var,
@@ -2309,10 +2309,10 @@ def instance_norm(
     See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`,
     :class:`~torch.nn.InstanceNorm3d` for details.
     """
-    if has_torch_function_unary(input):
+    if has_torch_function_variadic(input, running_mean, running_var, weight, bias):
         return handle_torch_function(
             instance_norm,
-            (input,),
+            (input, running_mean, running_var, weight, bias),
             input,
             running_mean=running_mean,
             running_var=running_var,
@@ -2340,9 +2340,9 @@ def layer_norm(
 
     See :class:`~torch.nn.LayerNorm` for details.
     """
-    if has_torch_function_unary(input):
+    if has_torch_function_variadic(input, weight, bias):
         return handle_torch_function(
-            layer_norm, (input,), input, normalized_shape, weight=weight, bias=bias, eps=eps
+            layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps
         )
     return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
 
@@ -2354,8 +2354,8 @@ def group_norm(
 
     See :class:`~torch.nn.GroupNorm` for details.
     """
-    if has_torch_function_unary(input):
-        return handle_torch_function(group_norm, (input,), input, num_groups, weight=weight, bias=bias, eps=eps)
+    if has_torch_function_variadic(input, weight, bias):
+        return handle_torch_function(group_norm, (input, weight, bias,), input, num_groups, weight=weight, bias=bias, eps=eps)
     _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:]))
     return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled)
 
@@ -2515,10 +2515,10 @@ def nll_loss(
         >>> output = F.nll_loss(F.log_softmax(input), target)
         >>> output.backward()
     """
-    if has_torch_function_variadic(input, target):
+    if has_torch_function_variadic(input, target, weight):
         return handle_torch_function(
             nll_loss,
-            (input, target),
+            (input, target, weight),
             input,
             target,
             weight=weight,
@@ -2828,10 +2828,10 @@ def cross_entropy(
         >>> loss = F.cross_entropy(input, target)
         >>> loss.backward()
     """
-    if has_torch_function_variadic(input, target):
+    if has_torch_function_variadic(input, target, weight):
         return handle_torch_function(
             cross_entropy,
-            (input, target),
+            (input, target, weight),
             input,
             target,
             weight=weight,
@@ -2887,10 +2887,10 @@ def binary_cross_entropy(
         >>> loss = F.binary_cross_entropy(F.sigmoid(input), target)
         >>> loss.backward()
     """
-    if has_torch_function_variadic(input, target):
+    if has_torch_function_variadic(input, target, weight):
         return handle_torch_function(
             binary_cross_entropy,
-            (input, target),
+            (input, target, weight),
             input,
             target,
             weight=weight,
@@ -2959,10 +2959,10 @@ def binary_cross_entropy_with_logits(
          >>> loss = F.binary_cross_entropy_with_logits(input, target)
          >>> loss.backward()
     """
-    if has_torch_function_variadic(input, target):
+    if has_torch_function_variadic(input, target, weight, pos_weight):
         return handle_torch_function(
             binary_cross_entropy_with_logits,
-            (input, target),
+            (input, target, weight, pos_weight),
             input,
             target,
             weight=weight,
@@ -3243,10 +3243,10 @@ def multilabel_soft_margin_loss(
 
     See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details.
     """
-    if has_torch_function_variadic(input, target):
+    if has_torch_function_variadic(input, target, weight):
         return handle_torch_function(
             multilabel_soft_margin_loss,
-            (input, target),
+            (input, target, weight),
             input,
             target,
             weight=weight,
@@ -3323,10 +3323,10 @@ def multi_margin_loss(
 
     See :class:`~torch.nn.MultiMarginLoss` for details.
     """
-    if has_torch_function_variadic(input, target):
+    if has_torch_function_variadic(input, target, weight):
         return handle_torch_function(
             multi_margin_loss,
-            (input, target),
+            (input, target, weight),
             input,
             target,
             p=p,
@@ -4443,8 +4443,8 @@ def normalize(input: Tensor, p: float = 2.0, dim: int = 1, eps: float = 1e-12, o
         out (Tensor, optional): the output tensor. If :attr:`out` is used, this
                                 operation won't be differentiable.
     """
-    if has_torch_function_unary(input):
-        return handle_torch_function(normalize, (input,), input, p=p, dim=dim, eps=eps, out=out)
+    if has_torch_function_variadic(input, out):
+        return handle_torch_function(normalize, (input, out), input, p=p, dim=dim, eps=eps, out=out)
     if out is None:
         denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input)
         return input / denom
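
Every hunk above follows the same pattern: each `Optional[Tensor]` argument is added both to the `has_torch_function_variadic` check and to the relevant-args tuple handed to `handle_torch_function`. A minimal sketch of that pattern outside the diff (`my_scale` is a hypothetical functional, not part of `torch.nn.functional`; `None` arguments never report an override, so they are safe to pass to the check):

    from typing import Optional

    import torch
    from torch import Tensor
    from torch.overrides import handle_torch_function, has_torch_function_variadic

    def my_scale(input: Tensor, weight: Optional[Tensor] = None) -> Tensor:
        # Include the optional tensor in both the check and the relevant-args tuple,
        # mirroring the change applied to the functionals above.
        if has_torch_function_variadic(input, weight):
            return handle_torch_function(my_scale, (input, weight), input, weight=weight)
        return input * weight if weight is not None else input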