Const trace error v2 (#18535)
authorElias Ellison <eellison@fb.com>
Wed, 27 Mar 2019 21:28:11 +0000 (14:28 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 27 Mar 2019 21:40:56 +0000 (14:40 -0700)
Summary:
Trying to reland https://github.com/pytorch/pytorch/pull/18298
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18535

Differential Revision: D14652391

Pulled By: eellison

fbshipit-source-id: 699e30045dd5f14f0a2b98378272045a292e1e2a

test/test_jit.py
torch/csrc/jit/pybind_utils.h
torch/csrc/jit/tracer.cpp

index 6ef2c1f..94c3e89 100644 (file)
@@ -7267,6 +7267,21 @@ a")
 
             self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True)
 
+    def test_tensor_with_grad_as_constant(self):
+        param = torch.randn(3).requires_grad_()
+        x = torch.randn(3)
+
+        def f(x):
+            return x + param
+        with self.assertRaisesRegex(RuntimeError, "Cannot insert a Tensor that requires grad as a constant"):
+            torch.jit.trace(f, x)
+
+    def test_non_tensor_tracing(self):
+        def f(x):
+            return x + param
+        with self.assertRaisesRegex(RuntimeError, "inputs or outputs of traced functions, but instead got value of type int."):
+            torch.jit.trace(f, (1,))
+
     def test_type_annotation_module(self):
         class BaseModule(torch.jit.ScriptModule):
             def foo(self, x):
index 1e0f5fd..dedc639 100644 (file)
@@ -92,9 +92,14 @@ inline IValue toIValue(py::handle input) {
     }
     return Tuple::create(s);
   } else {
-    AT_ERROR(
-        "Only tensors and (possibly nested) tuples of tensors are supported "
-        "as inputs or outputs of traced functions");
+    throw std::runtime_error(c10::str(
+        "Only tensors and (possibly nested) tuples of tensors are supported ",
+        "as inputs or outputs of traced functions",
+        ", but instead got value of type ",
+        py::str(input.get_type().attr("__name__")),
+        ".",
+        "\nValue: ",
+        py::repr(input)));
   }
 }
 
index 7855b75..38f7aa4 100644 (file)
@@ -102,6 +102,13 @@ Value* getValueTrace(const IValue& var) {
     }
 
     // Didn't find it. Bake in a constant
+    if (ten.is_variable() && ten.requires_grad()) {
+      std::ostringstream oss;
+      oss << "Cannot insert a Tensor that requires grad as a constant. "
+          << "Consider making it a parameter or input, or detaching the gradient\n";
+      throw std::runtime_error(oss.str());
+    }
+
     Value* constant = state->graph->insertConstant(ten);
     recordSourceLocation(constant->node());
     constant->inferTypeFrom(ten);
@@ -229,7 +236,8 @@ std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) {
       return Tuple::create(std::move(elems));
     } else {
       AT_ERROR(
-          "Only tensors or tuples of tensors can be inputs to traced functions");
+          "Only tensors or tuples of tensors can be inputs to traced functions. Got ",
+          type);
     }
   };
   for (IValue& input : inputs) {
@@ -429,10 +437,7 @@ void addInputs(Node* n, const char* name, at::IntArrayRef value) {
 void addInputs(Node* n, const char* name, const ArrayRef<double>& value) {
   AT_ERROR("Tracing float lists currently not supported!");
 }
-void addInputs(
-    Node* n,
-    const char* name,
-    const std::vector<double>& value) {
+void addInputs(Node* n, const char* name, const std::vector<double>& value) {
   AT_ERROR("Tracing float lists currently not supported!");
 }