#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/alias_info.h>
+#include <unordered_map>
namespace c10 {
// arguments are not checked by schema
const bool is_vararg_;
const bool is_varret_;
+ void checkArg(const IValue& value, const Argument& argument, optional<size_t> pos) const;
public:
const std::string& name() const {
is_vararg(),
is_varret());
}
+
// Check that inputs have the correct types and appends any missing default
// values.
- void checkAndNormalizeInputs(std::vector<IValue>& inputs) const;
+ void checkAndNormalizeInputs(
+ std::vector<IValue>& inputs,
+ const std::unordered_map<std::string, IValue>& kwargs) const;
+
+ void findErrorInKwargs(const std::vector<std::string>& kwargs) const;
};
inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) {
return out << arg.type()->str() << " " << arg.name() << (arg.default_value() ? "=<default>" : "");
}
-inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
- // eventually this should look almost identical to python arg parser, but
- // it is simpler for now to work directly on this schema
-
- out << schema.name();
- out << "(";
-
- bool seen_kwarg_only = false;
- for(size_t i = 0; i < schema.arguments().size(); ++i) {
- if (i > 0) out << ", ";
- if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
- out << "*, ";
- seen_kwarg_only = true;
- }
- out << schema.arguments()[i];
- }
-
- if(schema.is_vararg()) {
- if(schema.arguments().size() > 0)
- out << ", ";
- out << "...";
- }
-
- out << ") -> ";
- if (schema.returns().size() == 1) {
- out << schema.returns().at(0).type()->str();
- } else if (schema.returns().size() > 1) {
- out << "(";
- for (size_t i = 0; i < schema.returns().size(); ++i) {
- if (i > 0) out << ", ";
- out << schema.returns()[i].type()->str();
- }
- out << ")";
- }
- return out;
-}
+inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema);
inline std::string toString(const FunctionSchema& schema) {
std::ostringstream str;
return str.str();
}
-inline void FunctionSchema::checkAndNormalizeInputs(std::vector<IValue>& inputs) const {
- // Do we have more inputs than the schema accepts?
- AT_CHECK(
- inputs.size() <= arguments().size(),
- "Expected at most ",
- arguments().size(),
- " argument(s) for operator '",
- name(),
- "', but received ",
- inputs.size(),
- " argument(s). Declaration: ",
- *this);
-
- for (size_t pos = 0; pos < arguments().size(); ++pos) {
- const auto& argument = arguments()[pos];
- if (pos < inputs.size()) {
- if (!isSubvalueOf(inputs[pos], argument.type())) {
- AT_ERROR(
- "Expected value of type ",
- *argument.type(),
- " for argument '",
- argument.name(),
- "' in position ",
- pos,
- ", but instead got value of type ",
- attemptToRecoverType(inputs[pos])->str(),
- ". Declaration: ",
- *this);
- }
- } else if (argument.default_value()) {
- inputs.push_back(*argument.default_value());
- } else {
- AT_ERROR(
- name(),
- "() is missing value for argument '",
- argument.name(),
- "'. Declaration: ",
- *this);
- }
- }
-}
-
} // namespace c10
+
+#include <ATen/core/function_schema_inl.h>
--- /dev/null
+#pragma once
+
+// note: windows build doesn't find symbols in operator files unless
+// this is a header file
+
+namespace c10 {
+
+inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
+ // eventually this should look almost identical to python arg parser, but
+ // it is simpler for now to work directly on this schema
+
+ out << schema.name();
+ out << "(";
+
+ bool seen_kwarg_only = false;
+ for(size_t i = 0; i < schema.arguments().size(); ++i) {
+ if (i > 0) out << ", ";
+ if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
+ out << "*, ";
+ seen_kwarg_only = true;
+ }
+ out << schema.arguments()[i];
+ }
+
+ if(schema.is_vararg()) {
+ if(schema.arguments().size() > 0)
+ out << ", ";
+ out << "...";
+ }
+
+ out << ") -> ";
+ if (schema.returns().size() == 1) {
+ out << schema.returns().at(0).type()->str();
+ } else if (schema.returns().size() > 1) {
+ out << "(";
+ for (size_t i = 0; i < schema.returns().size(); ++i) {
+ if (i > 0) out << ", ";
+ out << schema.returns()[i].type()->str();
+ }
+ out << ")";
+ }
+ return out;
+}
+
+inline void FunctionSchema::checkArg(const IValue& value, const Argument& argument, optional<size_t> pos) const {
+ if (!isSubvalueOf(value, argument.type())) {
+ std::string position = pos ? ::c10::str(" in position ", *pos) : "";
+ AT_ERROR(
+ "Expected value of type ",
+ *argument.type(),
+ " for argument '",
+ argument.name(),
+ "'",
+ position,
+ ", but instead got value of type ",
+ attemptToRecoverType(value)->str(),
+ ". Declaration: ",
+ *this);
+ }
+}
+
+inline void FunctionSchema::findErrorInKwargs(const std::vector<std::string>& kwargs) const {
+ // First check if any of the kwargs are unknown, i.e. don't match the name of
+ // any argument in the schema.
+ for (const auto& kwarg : kwargs) {
+ if (!std::count_if(
+ arguments().begin(),
+ arguments().end(),
+ [&kwarg](const Argument& argument) {
+ return argument.name() == kwarg;
+ })) {
+ throw std::runtime_error(c10::str(
+ "Unknown keyword argument '",
+ kwarg,
+ "' for operator '",
+ name(),
+ "'. Schema: ",
+ *this));
+ }
+ }
+ // If there are unconsumed kwargs but none of them were unknown, the first
+ // positional argument present in the kwargs is duplicated.
+ for (const auto& argument : arguments()) {
+ if (std::find(kwargs.begin(), kwargs.end(), argument.name()) != kwargs.end()) {
+ AT_ASSERT(!argument.default_value());
+ throw std::runtime_error(c10::str(
+ "Argument '",
+ argument.name(),
+ "' specified both as positional and ",
+ "keyword argument. Schema: ",
+ *this));
+ }
+ }
+}
+
+inline void FunctionSchema::checkAndNormalizeInputs(
+ std::vector<IValue>& inputs,
+ const std::unordered_map<std::string, IValue>& kwargs) const {
+ // Do we have more inputs than the schema accepts?
+ AT_CHECK(
+ inputs.size() <= arguments().size(),
+ "Expected at most ",
+ arguments().size(),
+ " argument(s) for operator '",
+ name(),
+ "', but received ",
+ inputs.size(),
+ " argument(s). Declaration: ",
+ *this);
+
+ size_t consumed_kwargs = 0;
+ for (size_t pos = 0; pos < arguments().size(); ++pos) {
+ const auto& argument = arguments()[pos];
+ if (pos < inputs.size()) {
+ checkArg(inputs[pos], argument, pos);
+ continue;
+ }
+ auto it = kwargs.find(argument.name());
+ if (it != kwargs.end()) {
+ checkArg(it->second, argument, nullopt);
+ inputs.push_back(it->second);
+ consumed_kwargs++;
+ continue;
+ }
+ if (argument.default_value()) {
+ inputs.push_back(*argument.default_value());
+ continue;
+ }
+ AT_ERROR(
+ name(),
+ "() is missing value for argument '",
+ argument.name(),
+ "'. Declaration: ",
+ *this);
+ }
+ if (consumed_kwargs != kwargs.size()) {
+ std::vector<std::string> names;
+ for(const auto& k : kwargs) {
+ names.emplace_back(k.first);
+ }
+ findErrorInKwargs(names);
+ }
+}
+
+} // namespace c10
_(CreateAutodiffSubgraphs) \
_(CustomOperators) \
_(CustomOperatorAliasing) \
+ _(IValueKWargs) \
_(Differentiate) \
_(DifferentiateWithRequiresGrad) \
_(DynamicDAG) \
#include "torch/csrc/jit/irparser.h"
#include "torch/csrc/jit/passes/alias_analysis.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/jit.h"
namespace torch {
namespace jit {
testing::FileCheck().run(text, *graph);
}
}
+
+void testIValueKWargs() {
+ const auto text = R"(
+ def foo(a : int, b : int, c : int = 4):
+ return a + 2*b + 3*c
+ )";
+ auto cu = compile(text);
+ auto result = cu->get_function("foo")({1}, {{"b", 3}});
+ ASSERT_EQ(result.toInt(), 19);
+}
+
} // namespace test
} // namespace jit
} // namespace torch
namespace torch {
namespace jit {
-namespace detail {
-
-using ::c10::Argument;
-using ::c10::FunctionSchema;
// error reporting: when reporting user-caused errors, these functions should
// not use AT_ERROR macros, since these macros add stack trace information
// that is confusing to display to the end user since it always reports
// locations in libtorch code rather than user code.
-inline void findErrorInKwargs(const FunctionSchema& schema, py::kwargs kwargs) {
- const auto& arguments = schema.arguments();
- // First check if any of the kwargs are unknown, i.e. don't match the name of
- // any argument in the schema.
- for (const auto& kwarg : kwargs) {
- const auto key = py::cast<std::string>(kwarg.first);
- if (!std::count_if(
- arguments.begin(),
- arguments.end(),
- [&key](const Argument& argument) {
- return argument.name() == key;
- })) {
- throw std::runtime_error(c10::str(
- "Unknown keyword argument '",
- key,
- "' for operator '",
- schema.name(),
- "'. Schema: ",
- schema));
- }
- }
- // If there are unconsumed kwargs but none of them were unknown, the first
- // positional argument present in the kwargs is duplicated.
- for (const auto& argument : arguments) {
- if (kwargs.contains(argument.name().c_str())) {
- AT_ASSERT(!argument.default_value());
- throw std::runtime_error(c10::str(
- "Argument '",
- argument.name(),
- "' specified both as positional and ",
- "keyword argument. Schema: ",
- schema));
- }
- }
-}
-} // namespace detail
-
inline IValue toIValue(py::handle input) {
if (THPVariable_Check(input.ptr())) {
auto ten = py::cast<at::Tensor>(input);
const auto obj = ivalue.toObject();
const auto classType = ClassType::get(obj->name());
AT_ASSERT(classType);
- auto pyClass = py::module::import("torch.jit")
- .attr("_get_script_class")(obj->name());
+ auto pyClass =
+ py::module::import("torch.jit").attr("_get_script_class")(obj->name());
auto pyObj = pyClass.attr("__new__")(pyClass);
-
const auto numAttrs = classType->numAttributes();
for (size_t slot = 0; slot < numAttrs; slot++) {
}
if (consumed_kwargs != kwargs.size()) {
- detail::findErrorInKwargs(schema, kwargs);
+ std::vector<std::string> names;
+ for (const auto& kwarg : kwargs) {
+ names.emplace_back(py::cast<std::string>(kwarg.first));
+ }
+ schema.findErrorInKwargs(names);
}
return stack;
struct Def;
struct SugaredValue;
struct Function;
+using Kwargs = std::unordered_map<std::string, IValue>;
using Resolver = std::function<std::shared_ptr<SugaredValue>(
const std::string& name,
run(stack);
}
- IValue operator()(std::vector<IValue> stack) {
- getSchema().checkAndNormalizeInputs(stack);
+ IValue operator()(
+ std::vector<IValue> stack,
+ const Kwargs& kwargs = Kwargs()) {
+ getSchema().checkAndNormalizeInputs(stack, kwargs);
run(stack);
return stack.front();
}
mutable std::unique_ptr<FunctionSchema> schema_;
};
-
// A CompilationUnit is a list of named script::Functions
// with helper methods to iterate the list, or invoke the function.
// Classes have a CompilationUnit holding the class methods
run(stack);
}
- IValue operator()(std::vector<IValue> stack) {
- getSchema().checkAndNormalizeInputs(stack);
+ IValue operator()(
+ std::vector<IValue> stack,
+ const Kwargs& kwargs = Kwargs()) {
+ getSchema().checkAndNormalizeInputs(stack, kwargs);
for (auto input : initial_ivalues_) {
push(stack, input.value());
}