Fix AdaptiveLogSoftmaxWithLoss's constructor (#16694)
authorwbydo <wbydo@users.noreply.github.com>
Fri, 15 Feb 2019 14:43:40 +0000 (06:43 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 15 Feb 2019 14:58:00 +0000 (06:58 -0800)
Summary:
t-ken1 and I are members of a same team.
I have added test codes about the pull request https://github.com/pytorch/pytorch/pull/16656.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16694

Differential Revision: D14070106

Pulled By: ezyang

fbshipit-source-id: ff784dbf45e96a6bcf9a4b5cb9544a661a8acad2

test/test_nn.py
torch/nn/modules/adaptive.py

index c3761bc..5b23b86 100644 (file)
@@ -6964,6 +6964,12 @@ class TestNN(NNTestCase):
         with self.assertRaises(ValueError):
             _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 25], div_value=2.)
 
+        with self.assertRaisesRegex(ValueError, "cutoffs should be a sequence of unique,"):
+            _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 20], div_value=2.)
+
+        # not raise
+        _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 19], div_value=2.)
+
         # input shapes
         with self.assertRaisesRegex(RuntimeError, r"Input and target should have the same size"):
             asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.)
index d3cfb7b..595ceae 100644 (file)
@@ -107,7 +107,7 @@ class AdaptiveLogSoftmaxWithLoss(Module):
 
         if (cutoffs != sorted(cutoffs)) \
                 or (min(cutoffs) <= 0) \
-                or (max(cutoffs) >= (n_classes - 1)) \
+                or (max(cutoffs) > (n_classes - 1)) \
                 or (len(set(cutoffs)) != len(cutoffs)) \
                 or any([int(c) != c for c in cutoffs]):