Revert D14584266: [pytorch][PR] Better error message for tensor with grad as constant...
authorMichael Suo <suo@fb.com>
Sat, 23 Mar 2019 09:47:57 +0000 (02:47 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 23 Mar 2019 09:50:54 +0000 (02:50 -0700)
Differential Revision:
D14584266

Original commit changeset: 4e7850dadc78

fbshipit-source-id: 3bb3b5006e469edff984c16e0ff8d5dac2862d88

test/test_jit.py
torch/csrc/jit/pybind_utils.h
torch/csrc/jit/tracer.cpp

index dc10320..8145c88 100644 (file)
@@ -7201,21 +7201,6 @@ a")
 
             self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True)
 
-    def test_tensor_with_grad_as_constant(self):
-        param = torch.randn(3).requires_grad_()
-        x = torch.randn(3)
-
-        def f(x):
-            return x + param
-        with self.assertRaisesRegex(RuntimeError, "Cannot insert a Tensor that requires grad as a constant"):
-            torch.jit.trace(f, x)
-
-    def test_non_tensor_tracing(self):
-        def f(x):
-            return x + param
-        with self.assertRaisesRegex(RuntimeError, "inputs or outputs of traced functions, but instead got value of type int."):
-            torch.jit.trace(f, (1,))
-
     def test_type_annotation_module(self):
         class BaseModule(torch.jit.ScriptModule):
             def foo(self, x):
index dedc639..1e0f5fd 100644 (file)
@@ -92,14 +92,9 @@ inline IValue toIValue(py::handle input) {
     }
     return Tuple::create(s);
   } else {
-    throw std::runtime_error(c10::str(
-        "Only tensors and (possibly nested) tuples of tensors are supported ",
-        "as inputs or outputs of traced functions",
-        ", but instead got value of type ",
-        py::str(input.get_type().attr("__name__")),
-        ".",
-        "\nValue: ",
-        py::repr(input)));
+    AT_ERROR(
+        "Only tensors and (possibly nested) tuples of tensors are supported "
+        "as inputs or outputs of traced functions");
   }
 }
 
index e195883..7855b75 100644 (file)
@@ -102,13 +102,6 @@ Value* getValueTrace(const IValue& var) {
     }
 
     // Didn't find it. Bake in a constant
-    if (ten.requires_grad()) {
-      std::ostringstream oss;
-      oss << "Cannot insert a Tensor that requires grad as a constant. "
-          << "Consider making it a parameter or input, or detaching the gradient\n";
-      throw std::runtime_error(oss.str());
-    }
-
     Value* constant = state->graph->insertConstant(ten);
     recordSourceLocation(constant->node());
     constant->inferTypeFrom(ten);
@@ -236,8 +229,7 @@ std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) {
       return Tuple::create(std::move(elems));
     } else {
       AT_ERROR(
-          "Only tensors or tuples of tensors can be inputs to traced functions. Got ",
-          type);
+          "Only tensors or tuples of tensors can be inputs to traced functions");
     }
   };
   for (IValue& input : inputs) {
@@ -437,7 +429,10 @@ void addInputs(Node* n, const char* name, at::IntArrayRef value) {
 void addInputs(Node* n, const char* name, const ArrayRef<double>& value) {
   AT_ERROR("Tracing float lists currently not supported!");
 }
-void addInputs(Node* n, const char* name, const std::vector<double>& value) {
+void addInputs(
+    Node* n,
+    const char* name,
+    const std::vector<double>& value) {
   AT_ERROR("Tracing float lists currently not supported!");
 }