From d85451c07b4321127efb301ad4ddd6ba44325342 Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Sat, 23 Mar 2019 02:47:57 -0700 Subject: [PATCH] Revert D14584266: [pytorch][PR] Better error message for tensor with grad as constant in tracing Differential Revision: D14584266 Original commit changeset: 4e7850dadc78 fbshipit-source-id: 3bb3b5006e469edff984c16e0ff8d5dac2862d88 --- test/test_jit.py | 15 --------------- torch/csrc/jit/pybind_utils.h | 11 +++-------- torch/csrc/jit/tracer.cpp | 15 +++++---------- 3 files changed, 8 insertions(+), 33 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index dc10320..8145c88 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -7201,21 +7201,6 @@ a") self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True) - def test_tensor_with_grad_as_constant(self): - param = torch.randn(3).requires_grad_() - x = torch.randn(3) - - def f(x): - return x + param - with self.assertRaisesRegex(RuntimeError, "Cannot insert a Tensor that requires grad as a constant"): - torch.jit.trace(f, x) - - def test_non_tensor_tracing(self): - def f(x): - return x + param - with self.assertRaisesRegex(RuntimeError, "inputs or outputs of traced functions, but instead got value of type int."): - torch.jit.trace(f, (1,)) - def test_type_annotation_module(self): class BaseModule(torch.jit.ScriptModule): def foo(self, x): diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h index dedc639..1e0f5fd 100644 --- a/torch/csrc/jit/pybind_utils.h +++ b/torch/csrc/jit/pybind_utils.h @@ -92,14 +92,9 @@ inline IValue toIValue(py::handle input) { } return Tuple::create(s); } else { - throw std::runtime_error(c10::str( - "Only tensors and (possibly nested) tuples of tensors are supported ", - "as inputs or outputs of traced functions", - ", but instead got value of type ", - py::str(input.get_type().attr("__name__")), - ".", - "\nValue: ", - py::repr(input))); + AT_ERROR( + "Only tensors and (possibly nested) tuples of tensors are supported " + "as inputs or outputs of traced functions"); } } diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp index e195883..7855b75 100644 --- a/torch/csrc/jit/tracer.cpp +++ b/torch/csrc/jit/tracer.cpp @@ -102,13 +102,6 @@ Value* getValueTrace(const IValue& var) { } // Didn't find it. Bake in a constant - if (ten.requires_grad()) { - std::ostringstream oss; - oss << "Cannot insert a Tensor that requires grad as a constant. " - << "Consider making it a parameter or input, or detaching the gradient\n"; - throw std::runtime_error(oss.str()); - } - Value* constant = state->graph->insertConstant(ten); recordSourceLocation(constant->node()); constant->inferTypeFrom(ten); @@ -236,8 +229,7 @@ std::pair, Stack> enter(Stack inputs) { return Tuple::create(std::move(elems)); } else { AT_ERROR( - "Only tensors or tuples of tensors can be inputs to traced functions. Got ", - type); + "Only tensors or tuples of tensors can be inputs to traced functions"); } }; for (IValue& input : inputs) { @@ -437,7 +429,10 @@ void addInputs(Node* n, const char* name, at::IntArrayRef value) { void addInputs(Node* n, const char* name, const ArrayRef& value) { AT_ERROR("Tracing float lists currently not supported!"); } -void addInputs(Node* n, const char* name, const std::vector& value) { +void addInputs( + Node* n, + const char* name, + const std::vector& value) { AT_ERROR("Tracing float lists currently not supported!"); } -- 2.7.4