Fix default dtype in shape analysis (#18968)
author    Elias Ellison <eellison@fb.com>
Mon, 8 Apr 2019 21:44:45 +0000 (14:44 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 8 Apr 2019 21:50:28 +0000 (14:50 -0700)
Summary:
Fix for https://github.com/pytorch/pytorch/issues/18823

Previously we were setting the dtype to Float, but in TorchScript the default is Double. Once the fix for https://github.com/pytorch/pytorch/issues/17662 lands, we will have to reevaluate (and this test will fail).

In the meantime, shape_analysis should still be consistent with the runtime default.
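
A minimal sketch of the behavior this patch pins down, assuming a TorchScript build from around this change (make_tensor is an illustrative function, not part of the patch; the FileCheck pattern mirrors the new test):

    import torch
    from torch.testing import FileCheck

    @torch.jit.script
    def make_tensor():
        # no explicit dtype: TorchScript's default here is double, not float
        return torch.rand([2, 2])

    out = make_tensor()
    assert out.dtype == torch.double  # runtime default dtype

    # After this fix, shape analysis annotates the graph with the same default,
    # so the propagated type is Double(*, *) rather than Float(*, *).
    FileCheck().check("Double(*, *)").check_not("Float(*, *)").run(make_tensor.graph_for())
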
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18968

Differential Revision: D14837939

Pulled By: eellison

fbshipit-source-id: 32383b55c14bdc7753e26dec33c39ab10124c255

test/test_jit.py
torch/csrc/jit/passes/shape_analysis.cpp

test/test_jit.py
index 9bc2b96..f6652fb 100644
@@ -8328,6 +8328,21 @@ a")
             return a + 1.0 - a
 
         self.checkScript(test_rand, ())
+        fn = torch.jit.script(test_rand)
+        out = fn()
+        self.assertEqual(out.dtype, torch.double)
+        g = fn.graph_for()
+        # check that shape analysis sets the type correctly
+        FileCheck().check("Double(*, *)").check_not("Float(*, *)").run(g)
+
+        @torch.jit.script
+        def randint():
+            return torch.randint(0, 5, [1, 2])
+        out = randint()
+        self.assertEqual(out.dtype, torch.double)
+        # although the dtype should arguably be int here, test that the runtime
+        # dtype and the shape analysis dtype are the same.
+        FileCheck().check("Double(*, *)").check_not("Float(*, *)").run(randint.graph_for())
 
     def test_erase_number_types(self):
         def func(a):
torch/csrc/jit/passes/shape_analysis.cpp
index cef7e79..75ff7fa 100644
@@ -782,7 +782,8 @@ class ShapePropagator {
             auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType();
             auto second_scalar_type = (*maybe_tensor_types)[1]->scalarType();
             size_t arg_for_type = 0;
-            if (c10::promoteTypes(first_scalar_type, second_scalar_type) != first_scalar_type) {
+            if (c10::promoteTypes(first_scalar_type, second_scalar_type) !=
+                first_scalar_type) {
               arg_for_type = 1;
             }
             return {broadcast(*maybe_tensor_types, arg_for_type)};
@@ -1180,7 +1181,7 @@ class ShapePropagator {
       if (!maybe_dtype_option)
         return {};
       auto dtype =
-          (maybe_dtype_option->isNone() ? at::kFloat
+          (maybe_dtype_option->isNone() ? at::kDouble
                                         : maybe_dtype_option->toScalarType());
 
       return {DimensionedTensorType::create(dtype, device, dim)};