self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True)
+ def test_tensor_with_grad_as_constant(self):
+ param = torch.randn(3).requires_grad_()
+ x = torch.randn(3)
+
+ def f(x):
+ return x + param
+ with self.assertRaisesRegex(RuntimeError, "Cannot insert a Tensor that requires grad as a constant"):
+ torch.jit.trace(f, x)
+
+ def test_non_tensor_tracing(self):
+ def f(x):
+ return x + param
+ with self.assertRaisesRegex(RuntimeError, "inputs or outputs of traced functions, but instead got value of type int."):
+ torch.jit.trace(f, (1,))
+
def test_type_annotation_module(self):
class BaseModule(torch.jit.ScriptModule):
def foo(self, x):
}
return Tuple::create(s);
} else {
- AT_ERROR(
- "Only tensors and (possibly nested) tuples of tensors are supported "
- "as inputs or outputs of traced functions");
+ throw std::runtime_error(c10::str(
+ "Only tensors and (possibly nested) tuples of tensors are supported ",
+ "as inputs or outputs of traced functions",
+ ", but instead got value of type ",
+ py::str(input.get_type().attr("__name__")),
+ ".",
+ "\nValue: ",
+ py::repr(input)));
}
}
}
// Didn't find it. Bake in a constant
+ if (ten.is_variable() && ten.requires_grad()) {
+ std::ostringstream oss;
+ oss << "Cannot insert a Tensor that requires grad as a constant. "
+ << "Consider making it a parameter or input, or detaching the gradient\n";
+ throw std::runtime_error(oss.str());
+ }
+
Value* constant = state->graph->insertConstant(ten);
recordSourceLocation(constant->node());
constant->inferTypeFrom(ten);
return Tuple::create(std::move(elems));
} else {
AT_ERROR(
- "Only tensors or tuples of tensors can be inputs to traced functions");
+ "Only tensors or tuples of tensors can be inputs to traced functions. Got ",
+ type);
}
};
for (IValue& input : inputs) {
void addInputs(Node* n, const char* name, const ArrayRef<double>& value) {
AT_ERROR("Tracing float lists currently not supported!");
}
-void addInputs(
- Node* n,
- const char* name,
- const std::vector<double>& value) {
+void addInputs(Node* n, const char* name, const std::vector<double>& value) {
AT_ERROR("Tracing float lists currently not supported!");
}