From 3bf93d769cdc39a0228b28b06b7742430e282b14 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Thu, 9 Sep 2021 08:12:30 -0700 Subject: [PATCH] [JIT] Add gradient check in constants (#64613) Summary: fixes internal issue Pull Request resolved: https://github.com/pytorch/pytorch/pull/64613 Reviewed By: Gamrix Differential Revision: D30799016 Pulled By: eellison fbshipit-source-id: 48ef52d1cac627919e6cd232216d24878a2a8b58 --- test/cpp/jit/test_misc.cpp | 7 +++++++ torch/csrc/jit/ir/constants.cpp | 7 ++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index 305d36a..e03920f 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -2685,6 +2685,13 @@ TEST(ComputeFlopsTest, Basic) { ASSERT_EQ(flops, 360); } +TEST(TestConstant, TensorGrad) { + auto graph = std::make_shared(); + IValue ten = torch::randn({3, 5}).requires_grad_(true); + auto con = tryInsertConstant(*graph, ten); + ASSERT_TRUE(con == c10::nullopt); +} + TEST(TestMutation, Basic) { auto graph = std::make_shared(); std::unordered_map vmap; diff --git a/torch/csrc/jit/ir/constants.cpp b/torch/csrc/jit/ir/constants.cpp index 228e42a..7c95d8e 100644 --- a/torch/csrc/jit/ir/constants.cpp +++ b/torch/csrc/jit/ir/constants.cpp @@ -10,7 +10,9 @@ namespace torch { namespace jit { bool insertableTensor(const at::Tensor& ten) { - return !ten.requires_grad(); + // bail if tensor has no storage i.e. opaque tensor used in MKLdnn. + // or gradients because we have no way of serializing them & are mutable + return !ten.requires_grad() && ten.has_storage(); } bool insertableIValue(const IValue& ivalue) { @@ -65,8 +67,7 @@ c10::optional tryInsertConstant( Node* n = g.create(prim::Constant); if (val.isTensor()) { at::Tensor ref = val.toTensor(); - if (!ref.has_storage()) { - // bail if tensor has no storage i.e. opaque tensor used in MKLdnn. + if (!insertableTensor(val.toTensor())) { n->destroy(); return c10::nullopt; } -- 2.7.4