self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True)
- def test_tensor_with_grad_as_constant(self):
- param = torch.randn(3).requires_grad_()
- x = torch.randn(3)
-
- def f(x):
- return x + param
- with self.assertRaisesRegex(RuntimeError, "Cannot insert a Tensor that requires grad as a constant"):
- torch.jit.trace(f, x)
-
- def test_non_tensor_tracing(self):
- def f(x):
- return x + param
- with self.assertRaisesRegex(RuntimeError, "inputs or outputs of traced functions, but instead got value of type int."):
- torch.jit.trace(f, (1,))
-
def test_type_annotation_module(self):
class BaseModule(torch.jit.ScriptModule):
def foo(self, x):
}
return Tuple::create(s);
} else {
- throw std::runtime_error(c10::str(
- "Only tensors and (possibly nested) tuples of tensors are supported ",
- "as inputs or outputs of traced functions",
- ", but instead got value of type ",
- py::str(input.get_type().attr("__name__")),
- ".",
- "\nValue: ",
- py::repr(input)));
+ AT_ERROR(
+ "Only tensors and (possibly nested) tuples of tensors are supported "
+ "as inputs or outputs of traced functions");
}
}
}
// Didn't find it. Bake in a constant
- if (ten.requires_grad()) {
- std::ostringstream oss;
- oss << "Cannot insert a Tensor that requires grad as a constant. "
- << "Consider making it a parameter or input, or detaching the gradient\n";
- throw std::runtime_error(oss.str());
- }
-
Value* constant = state->graph->insertConstant(ten);
recordSourceLocation(constant->node());
constant->inferTypeFrom(ten);
return Tuple::create(std::move(elems));
} else {
AT_ERROR(
- "Only tensors or tuples of tensors can be inputs to traced functions. Got ",
- type);
+ "Only tensors or tuples of tensors can be inputs to traced functions");
}
};
for (IValue& input : inputs) {
void addInputs(Node* n, const char* name, const ArrayRef<double>& value) {
AT_ERROR("Tracing float lists currently not supported!");
}
-void addInputs(Node* n, const char* name, const std::vector<double>& value) {
+void addInputs(
+ Node* n,
+ const char* name,
+ const std::vector<double>& value) {
AT_ERROR("Tracing float lists currently not supported!");
}