Allow undefined tensors as constants (#14120)
author David Riazati <davidriazati@fb.com>
Wed, 21 Nov 2018 00:42:00 +0000 (16:42 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 21 Nov 2018 00:54:27 +0000 (16:54 -0800)
Summary:
This PR inserts `prim::None` constants for undefined tensors. This comes up in the standard library when an `Optional[Tensor]` is statically determined to be `None`:

```python
@torch.jit.script
def fn(x=None):
    # type: (Optional[Tensor]) -> Tensor
    return torch.jit._unwrap_optional(x)

@torch.jit.script
def fn2():
    # type: () -> Tensor
    return fn()
```
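With this change, scripting `fn2` succeeds: the statically-`None` argument is folded into the graph as a constant rather than raising during constant insertion. A minimal check (a sketch; the exact node name depends on the TorchScript version):

```python
# Continuing the summary example: the default `None` for `x` should show up
# in fn2's graph as a None constant (prim::None here) instead of failing with
# "undefined tensors cannot become constants".
print(fn2.graph)
```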
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14120

Differential Revision: D13124625

Pulled By: driazati

fbshipit-source-id: 9eaa82e478c49c503f68ed89d8c770e8273ea569

aten/src/ATen/core/ivalue.h
test/test_jit.py
torch/csrc/jit/constants.cpp
torch/nn/functional.py

diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h
index 5293cd5..2245042 100644
@@ -343,6 +343,7 @@ struct CAFFE2_API IValue final {
     return Tag::None == tag;
   }
   std::string toNone() const {
+    AT_ASSERT(isNone());
     return "None";
   }
   // Scalar, which gets encoded as either an Int or a Double
diff --git a/test/test_jit.py b/test/test_jit.py
index 0531c58..9a0d91b 100644
@@ -9139,7 +9139,6 @@ EXCLUDE_SCRIPT = {
     'test_nn_binary_cross_entropy',
     'test_nn_binary_cross_entropy_size_average',
     'test_nn_cross_entropy',
-    'test_nn_binary_cross_entropy_with_logits',
     'test_nn_interpolate',
     'test_nn_fold',
     'test_nn_max_unpool1d',
@@ -9148,11 +9147,6 @@ EXCLUDE_SCRIPT = {
     'test_nn_gumbel_softmax',
     'test_nn_poisson_nll_loss',
     'test_nn_poisson_nll_loss_full',
-
-    # undefined tensors as constants
-    'test_nn_instance_norm',
-    'test_nn_normalize',
-    'test_nn_multilabel_soft_margin_loss',
 }
 
 DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
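The removed `EXCLUDE_SCRIPT` entries are `torch.nn.functional` ops whose `Optional[Tensor]` defaults previously hit the undefined-tensor error; with this PR they are expected to script. A hedged sketch of the kind of call this unblocks (assuming a build that includes this change):

```python
import torch
import torch.nn.functional as F

@torch.jit.script
def scripted_normalize(x):
    # type: (Tensor) -> Tensor
    # normalize's `out=None` default is an Optional[Tensor] that is
    # statically None here, so it lowers to a None constant.
    return F.normalize(x, p=2., dim=1)

print(scripted_normalize(torch.randn(2, 3)))
```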
diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp
index d8b41e1..b3f11ee 100644
@@ -15,7 +15,7 @@ Value* insertConstant(
   if(val.isTensor()) {
     at::Tensor ref = std::move(val).toTensor();
     if(!ref.defined()) {
-      throw constant_not_supported_error("undefined tensors cannot become constants");
+      return insertConstant(g, val, loc, scope);
     }
     if (ref.is_variable()) {
       ref = autograd::Variable(ref).data();
@@ -69,7 +69,7 @@ RegisterOperators reg({
         if(type->isSubtypeOf(DynamicType::get())) {
           auto t = autograd::make_variable(node->t(attr::value));
           return [t](Stack& stack) {
-            stack.push_back(t);
+            push(stack, t);
             return 0;
           };
         } else if (type->isSubtypeOf(BoolType::get())) {
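The first hunk stops `insertConstant` from throwing on undefined tensors; after `std::move(val).toTensor()`, the moved-from `IValue` appears to read as `None`, so the recursive call takes the `None` branch and emits the `prim::None` constant described in the summary. The second hunk swaps `stack.push_back(t)` for the `push(stack, t)` helper with no behavior change. A hedged end-to-end sketch (assuming `is None` refinement for Optionals, which TorchScript supports around this time):

```python
import torch

@torch.jit.script
def maybe_add(x, y=None):
    # type: (Tensor, Optional[Tensor]) -> Tensor
    # The None default no longer aborts constant insertion when this is
    # called without `y` from other scripted code.
    if y is None:
        return x
    return x + y

print(maybe_add(torch.ones(2)))                 # None-constant path
print(maybe_add(torch.ones(2), torch.ones(2)))  # defined-tensor path
```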
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index cac9a38..ac58b7c 100644
@@ -1443,6 +1443,7 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None,
     )
 
 
+@torch._jit_internal.weak_script
 def instance_norm(input, running_mean=None, running_var=None, weight=None,
                   bias=None, use_input_stats=True, momentum=0.1, eps=1e-5):
     # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor  # noqa
@@ -1458,7 +1459,9 @@ def instance_norm(input, running_mean=None, running_var=None, weight=None,
     )
 
 
+@torch._jit_internal.weak_script
 def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
+    # type: (Tensor, List[int], Optional[Tensor], Optional[Tensor], float) -> Tensor
     r"""Applies Layer Normalization for last certain number of dimensions.
 
     See :class:`~torch.nn.LayerNorm` for details.
@@ -1467,7 +1470,9 @@ def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
                             torch.backends.cudnn.enabled)
 
 
+@torch._jit_internal.weak_script
 def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
+    # type: (Tensor, int, Optional[Tensor], Optional[Tensor], float) -> Tensor
     r"""Applies Group Normalization for last certain number of dimensions.
 
     See :class:`~torch.nn.GroupNorm` for details.
@@ -1812,6 +1817,7 @@ def binary_cross_entropy(input, target, weight=None, size_average=None,
         input, target, weight, reduction_enum)
 
 
+@torch._jit_internal.weak_script
 def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None,
                                      reduce=None, reduction='mean', pos_weight=None):
     # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str, Optional[Tensor]) -> Tensor
@@ -1916,59 +1922,68 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
     return _pointwise_loss(lambda a, b: (a - b) ** 2, torch._C._nn.mse_loss, input, target, reduction)
 
 
+@torch._jit_internal.weak_script
 def margin_ranking_loss(input1, input2, target, margin=0, size_average=None,
                         reduce=None, reduction='mean'):
+    # type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
     r"""margin_ranking_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor
 
     See :class:`~torch.nn.MarginRankingLoss` for details.
     """  # noqa
     if size_average is not None or reduce is not None:
-        reduction = _Reduction.legacy_get_enum(size_average, reduce)
+        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
-        reduction = _Reduction.get_enum(reduction)
+        reduction_enum = _Reduction.get_enum(reduction)
     if input1.dim() == 0 or input2.dim() == 0 or target.dim() == 0:
         raise RuntimeError(("margin_ranking_loss does not support scalars, got sizes: "
                             "input1: {}, input2: {}, target: {} ".format(input1.size(), input2.size(), target.size())))
-    return torch.margin_ranking_loss(input1, input2, target, margin, reduction)
+    return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum)
 
 
+@torch._jit_internal.weak_script
 def hinge_embedding_loss(input, target, margin=1.0, size_average=None,
                          reduce=None, reduction='mean'):
+    # type: (Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
     r"""hinge_embedding_loss(input, target, margin=1.0, size_average=None, reduce=None, reduction='mean') -> Tensor
 
     See :class:`~torch.nn.HingeEmbeddingLoss` for details.
     """  # noqa
     if size_average is not None or reduce is not None:
-        reduction = _Reduction.legacy_get_enum(size_average, reduce)
+        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
-        reduction = _Reduction.get_enum(reduction)
-    return torch.hinge_embedding_loss(input, target, margin, reduction)
+        reduction_enum = _Reduction.get_enum(reduction)
+    return torch.hinge_embedding_loss(input, target, margin, reduction_enum)
 
 
+@torch._jit_internal.weak_script
 def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
+    # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
     r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
 
     See :class:`~torch.nn.MultiLabelMarginLoss` for details.
     """
     if size_average is not None or reduce is not None:
-        reduction = _Reduction.legacy_get_enum(size_average, reduce)
+        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
-        reduction = _Reduction.get_enum(reduction)
-    return torch._C._nn.multilabel_margin_loss(input, target, reduction)
+        reduction_enum = _Reduction.get_enum(reduction)
+    return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)
 
 
+@torch._jit_internal.weak_script
 def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
+    # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
     r"""soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
 
     See :class:`~torch.nn.SoftMarginLoss` for details.
     """
     if size_average is not None or reduce is not None:
-        reduction = _Reduction.legacy_get_enum(size_average, reduce)
+        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
-        reduction = _Reduction.get_enum(reduction)
-    return torch._C._nn.soft_margin_loss(input, target, reduction)
+        reduction_enum = _Reduction.get_enum(reduction)
+    return torch._C._nn.soft_margin_loss(input, target, reduction_enum)
 
 
+@torch._jit_internal.weak_script
 def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
                                 reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
@@ -1998,17 +2013,19 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
     return ret
 
 
+@torch._jit_internal.weak_script
 def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None,
                           reduce=None, reduction='mean'):
+    # type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
     r"""cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor
 
     See :class:`~torch.nn.CosineEmbeddingLoss` for details.
     """  # noqa
     if size_average is not None or reduce is not None:
-        reduction = _Reduction.legacy_get_enum(size_average, reduce)
+        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
-        reduction = _Reduction.get_enum(reduction)
-    return torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
+        reduction_enum = _Reduction.get_enum(reduction)
+    return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum)
 
 
 @torch._jit_internal.weak_script
@@ -2514,19 +2531,22 @@ Example::
 """)
 
 
+@torch._jit_internal.weak_script
 def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None,
                         reduce=None, reduction="mean"):
+    # type: (Tensor, Tensor, Tensor, float, float, float, bool, Optional[bool], Optional[bool], str) -> Tensor
     r"""
     See :class:`~torch.nn.TripletMarginLoss` for details
     """
     if size_average is not None or reduce is not None:
-        reduction = _Reduction.legacy_get_enum(size_average, reduce)
+        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
-        reduction = _Reduction.get_enum(reduction)
+        reduction_enum = _Reduction.get_enum(reduction)
     return torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps,
-                                     swap, reduction)
+                                     swap, reduction_enum)
 
 
+@torch._jit_internal.weak_script
 def normalize(input, p=2, dim=1, eps=1e-12, out=None):
     # type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor
     r"""Performs :math:`L_p` normalization of inputs over specified dimension.