(param
(ident x)
(variable (ident Tensor))
- (option))
+ (option)
+ (False))
(param
(ident y)
(variable (ident Tensor))
- (option))
+ (option)
+ (False))
(param
(ident z)
(variable (ident Tensor))
- (option)))
+ (option)
+ (False)))
(option))
(list
(assign
# type: (Tensor) -> Tensor
return fn2(input, [1])
+ def test_parser_kwargonly(self):
+ cu = torch.jit.CompilationUnit('''
+ def foo(x, *, y) -> Tuple[Tensor, Tensor]:
+ return x, x
+ def bar(x):
+ return foo(x, y=x)
+ ''')
+ self.assertTrue('*' in cu.module._get_method('foo').pretty_print_schema())
+ with self.assertRaisesRegex(RuntimeError, "not provided"):
+ torch.jit.CompilationUnit('''
+ def foo(x, *, y) -> Tuple[Tensor, Tensor]:
+ return x, x
+ def bar(x):
+ return foo(x, x)
+ ''')
+
def test_annoying_doubles(self):
mod = types.ModuleType("temp")
mod.inf = float("inf")
type,
N,
default_value,
- /*kwarg_only =*/false);
+ decl_arg.kwarg_only());
retval.push_back(arg);
}
return retval;
// aten::slice, we should separate it from this function.
if (dim) {
AT_ASSERT(input->type()->isSubtypeOf(TensorType::get()));
- args.emplace_back(loc, "dim", graph->insertConstant(dim.value(), nullptr, loc));
+ args.emplace_back(
+ loc, "dim", graph->insertConstant(dim.value(), nullptr, loc));
} else {
AT_ASSERT(!input->type()->isSubtypeOf(TensorType::get()));
}
return emitTupleSlice(loc, args[0], args[1], c10::nullopt);
}
}
- NamedValue step = NamedValue(loc, "step", graph->insertConstant(1, nullptr, loc));
+ NamedValue step =
+ NamedValue(loc, "step", graph->insertConstant(1, nullptr, loc));
return emitBuiltinCall(
loc, *graph, aten::slice, c10::nullopt, args, {step}, true);
}
+#include <torch/csrc/jit/script/parser.h>
#include <c10/util/Optional.h>
#include <torch/csrc/jit/script/lexer.h>
#include <torch/csrc/jit/script/parse_string_literal.h>
-#include <torch/csrc/jit/script/parser.h>
#include <torch/csrc/jit/script/tree.h>
#include <torch/csrc/jit/script/tree_views.h>
}
return Expr(prefix);
}
- template <typename T>
- List<T> parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) {
- auto r = L.cur().range;
+ void parseSequence(
+ int begin,
+ int sep,
+ int end,
+ const std::function<void()>& parse) {
if (begin != TK_NOTHING)
L.expect(begin);
- std::vector<T> elements;
if (L.cur().kind != end) {
do {
- elements.push_back((this->*parse)());
+ parse();
} while (L.nextIf(sep));
}
if (end != TK_NOTHING)
L.expect(end);
+ }
+ template <typename T>
+ List<T> parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) {
+ auto r = L.cur().range;
+ std::vector<T> elements;
+ parseSequence(
+ begin, sep, end, [&] { elements.emplace_back((this->*parse)()); });
return List<T>::create(r, elements);
}
return Subscript::create(range, Expr(value), subscript_exprs);
}
- TreeRef parseParam() {
+ TreeRef parseParam(bool kwarg_only) {
auto ident = parseIdent();
TreeRef type;
if (L.nextIf(':')) {
def = Maybe<Expr>::create(L.cur().range);
}
return Param::create(
- type->range(), Ident(ident), Expr(type), Maybe<Expr>(def));
+ type->range(), Ident(ident), Expr(type), Maybe<Expr>(def), kwarg_only);
}
Param parseBareTypeAnnotation() {
type.range(),
Ident::create(type.range(), ""),
type,
- Maybe<Expr>::create(type.range()));
+ Maybe<Expr>::create(type.range()),
+ /*kwarg_only=*/false);
}
Decl parseTypeComment() {
}
}
+ List<Param> parseParams() {
+ auto r = L.cur().range;
+ std::vector<Param> params;
+ bool kwarg_only = false;
+ parseSequence('(', ',', ')', [&] {
+ if (!kwarg_only && L.nextIf('*')) {
+ kwarg_only = true;
+ } else {
+ params.emplace_back(parseParam(kwarg_only));
+ }
+ });
+ return List<Param>::create(r, params);
+ }
Decl parseDecl() {
- auto paramlist = parseList('(', ',', ')', &ParserImpl::parseParam);
// Parse return type annotation
+ List<Param> paramlist = parseParams();
TreeRef return_type;
Maybe<Expr> return_annotation = parseReturnAnnotation();
L.expect(':');
"name", [](const Ident& self) { return self.name(); });
py::class_<Param, TreeView>(m, "Param")
- .def(py::init([](const Expr& type, const Ident& name) {
+ .def(py::init([](const Expr& type, const Ident& name, bool kwarg_only) {
return Param::create(
- name.range(), name, type, Maybe<Expr>::create(name.range()));
+ name.range(),
+ name,
+ type,
+ Maybe<Expr>::create(name.range()),
+ kwarg_only);
}));
py::class_<Attribute, TreeView>(m, "Attribute")
.def(py::init([](const Ident& name, const Expr& value) {
const SourceRange& range,
const Ident& ident,
const Expr& type,
- const Maybe<Expr>& def) {
- return Param(Compound::create(TK_PARAM, range, {ident, type, def}));
+ const Maybe<Expr>& def,
+ bool kwarg_only) {
+ TreeRef kwarg_only_tree =
+ Compound::create(kwarg_only ? TK_TRUE : TK_FALSE, range, {});
+ return Param(
+ Compound::create(TK_PARAM, range, {ident, type, def, kwarg_only_tree}));
}
Ident ident() const {
return Ident(subtree(0));
Maybe<Expr> defaultValue() const {
return Maybe<Expr>(subtree(2));
}
+ bool kwarg_only() const {
+ return TK_TRUE == subtree(3)->kind();
+ }
Param withType(const Expr& typ) const {
- return Param::create(range(), ident(), typ, defaultValue());
+ return Param::create(range(), ident(), typ, defaultValue(), kwarg_only());
}
};
def expand(self,
size: List[int],
+ *,
implicit: bool=False):
self_size = self.size()
def backward(grad_output):
return cache_it->second;
} else {
auto schema_str = canonicalSchemaString(schema);
- // JIT doesn't support keyword only arguments.
- // Remove ' *,' in schema before looking up
- // TODO: #16921 properly support keyword only arguments in JIT.
- auto n = schema_str.find("*, ");
- if (n != std::string::npos) {
- schema_str = schema_str.erase(n, 3);
- }
-
auto sym_script_it = schema_to_graphs.find(schema_str);
+
if (sym_script_it != schema_to_graphs.end()) {
cached_gradient_pairs.emplace_hint(
cache_it, &schema, sym_script_it->second);
_vararg_kwarg_err = ("Compiled functions can't take variable number of arguments "
- "or keyword-only arguments")
+ "or use keyword-only arguments with defaults")
def build_param_list(ctx, py_args):
if py_args.vararg is not None or py_args.kwarg is not None:
raise ValueError(_vararg_kwarg_err)
- if not PY2 and (py_args.kw_defaults or py_args.kwonlyargs):
+ if not PY2 and py_args.kw_defaults:
raise ValueError(_vararg_kwarg_err)
- return [build_param(ctx, arg) for arg in py_args.args]
+ result = [build_param(ctx, arg, False) for arg in py_args.args]
+ if not PY2:
+ result += [build_params(ctx, arg, True) for arg in py_args.kwonlyargs]
+ return result
-def build_param(ctx, py_arg):
+def build_param(ctx, py_arg, kwarg_only):
# NB: In Python3 py_arg is a pair of (str arg, expr? annotation)
# In Python2 py_arg is a Name (Expr subclass)
name = py_arg.id if PY2 else py_arg.arg
annotation_expr = build_expr(ctx, py_arg.annotation)
else:
annotation_expr = Var(Ident(r, 'Tensor'))
- return Param(annotation_expr, Ident(r, name))
+ return Param(annotation_expr, Ident(r, name), kwarg_only)
def get_default_args(fn):