From 930fb2f31986a716f9588aad4dd372c28550d7cb Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Mon, 8 Apr 2019 14:44:45 -0700
Subject: [PATCH] Fix default dtype in shape analysis (#18968)

Summary:
Fix for https://github.com/pytorch/pytorch/issues/18823

Previously, shape analysis set the default dtype to Float, but in TorchScript
the default is Double. When the fix for
https://github.com/pytorch/pytorch/issues/17662 lands, we will have to
reevaluate (and this test will fail). We should still be consistent in
shape_analysis in the meantime.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18968

Differential Revision: D14837939

Pulled By: eellison

fbshipit-source-id: 32383b55c14bdc7753e26dec33c39ab10124c255
---
 test/test_jit.py                         | 15 +++++++++++++++
 torch/csrc/jit/passes/shape_analysis.cpp |  5 +++--
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index 9bc2b96..f6652fb 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -8328,6 +8328,21 @@ a")
             return a + 1.0 - a
 
         self.checkScript(test_rand, ())
+        fn = torch.jit.script(test_rand)
+        out = fn()
+        self.assertEqual(out.dtype, torch.double)
+        g = fn.graph_for()
+        # Testing shape analysis correctly setting type
+        FileCheck().check("Double(*, *)").check_not("Float(*, *)").run(g)
+
+        @torch.jit.script
+        def randint():
+            return torch.randint(0, 5, [1, 2])
+        out = randint()
+        self.assertEqual(out.dtype, torch.double)
+        # although the type should be int here, testing that the runtime dtype
+        # and shape analysis dtype is the same.
+        FileCheck().check("Double(*, *)").check_not("Float(*, *)").run(randint.graph_for())
 
     def test_erase_number_types(self):
         def func(a):
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index cef7e79..75ff7fa 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -782,7 +782,8 @@ class ShapePropagator {
           auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType();
           auto second_scalar_type = (*maybe_tensor_types)[1]->scalarType();
           size_t arg_for_type = 0;
-          if (c10::promoteTypes(first_scalar_type, second_scalar_type) != first_scalar_type) {
+          if (c10::promoteTypes(first_scalar_type, second_scalar_type) !=
+              first_scalar_type) {
             arg_for_type = 1;
           }
           return {broadcast(*maybe_tensor_types, arg_for_type)};
@@ -1180,7 +1181,7 @@ class ShapePropagator {
     if (!maybe_dtype_option)
       return {};
     auto dtype =
-        (maybe_dtype_option->isNone() ? at::kFloat
+        (maybe_dtype_option->isNone() ? at::kDouble
                                       : maybe_dtype_option->toScalarType());
 
     return {DimensionedTensorType::create(dtype, device, dim)};
-- 
2.7.4
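
For readers skimming the patch, a minimal standalone repro of the behavior the
new tests pin down. This is a sketch, assuming the TorchScript semantics of
this era (scripted factory ops and float literals defaulting to double, as the
summary and tests above describe); the function name scripted_rand is invented
for illustration:

    import torch

    # In TorchScript of this era, float literals and factory-op defaults
    # map to double, while eager mode defaults to float32.
    @torch.jit.script
    def scripted_rand():
        return torch.rand([3, 4]) + 1.0

    print(torch.rand([3, 4]).dtype)  # eager: torch.float32
    out = scripted_rand()
    print(out.dtype)                 # scripted: torch.float64 (double)

    # With this patch, shape analysis annotates the output as Double(*, *),
    # matching the runtime dtype instead of the previous Float(*, *).
    print(scripted_rand.graph_for())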