[JIT] Add gradient check in constants (#64613)
authorElias Ellison <eellison@devfair044.h1.fair>
Thu, 9 Sep 2021 15:12:30 +0000 (08:12 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 15:13:57 +0000 (08:13 -0700)
Summary:
fixes internal issue

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64613

Reviewed By: Gamrix

Differential Revision: D30799016

Pulled By: eellison

fbshipit-source-id: 48ef52d1cac627919e6cd232216d24878a2a8b58

test/cpp/jit/test_misc.cpp
torch/csrc/jit/ir/constants.cpp

index 305d36a..e03920f 100644 (file)
@@ -2685,6 +2685,13 @@ TEST(ComputeFlopsTest, Basic) {
   ASSERT_EQ(flops, 360);
 }
 
+TEST(TestConstant, TensorGrad) {
+  auto graph = std::make_shared<Graph>();
+  IValue ten = torch::randn({3, 5}).requires_grad_(true);
+  auto con = tryInsertConstant(*graph, ten);
+  ASSERT_TRUE(con == c10::nullopt);
+}
+
 TEST(TestMutation, Basic) {
   auto graph = std::make_shared<Graph>();
   std::unordered_map<std::string, Value*> vmap;
index 228e42a..7c95d8e 100644 (file)
@@ -10,7 +10,9 @@ namespace torch {
 namespace jit {
 
 bool insertableTensor(const at::Tensor& ten) {
-  return !ten.requires_grad();
+  // bail if tensor has no storage i.e. opaque tensor used in MKLdnn.
+  // or gradients because we have no way of serializing them & are mutable
+  return !ten.requires_grad() && ten.has_storage();
 }
 
 bool insertableIValue(const IValue& ivalue) {
@@ -65,8 +67,7 @@ c10::optional<Value*> tryInsertConstant(
   Node* n = g.create(prim::Constant);
   if (val.isTensor()) {
     at::Tensor ref = val.toTensor();
-    if (!ref.has_storage()) {
-      // bail if tensor has no storage i.e. opaque tensor used in MKLdnn.
+    if (!insertableTensor(val.toTensor())) {
       n->destroy();
       return c10::nullopt;
     }