From: Zachary DeVito
Date: Thu, 21 Feb 2019 23:24:23 +0000 (-0800)
Subject: Partial support for kwarg_only arguments in script (#17339)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1152
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4c6da649e537882c40223c7e1f66403d321a04bc;p=platform%2Fupstream%2Fpytorch.git

Partial support for kwarg_only arguments in script (#17339)

Summary:
This provides the minimum support necessary to allow derivative formulas
for functions that have a kwarg-only specifier in their schema. Support
for default arguments on kwarg-only parameters in the non-parser frontend
is not yet complete.

Fixes #16921

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17339

Differential Revision: D14160923

Pulled By: zdevito

fbshipit-source-id: 822e964c5a3fe2806509cf24d9f51c6dc01711c3
---

diff --git a/test/expect/TestScript.test_python_frontend.expect b/test/expect/TestScript.test_python_frontend.expect
index 649b714..0f7fac3 100644
--- a/test/expect/TestScript.test_python_frontend.expect
+++ b/test/expect/TestScript.test_python_frontend.expect
@@ -5,15 +5,18 @@
     (param
       (ident x)
       (variable (ident Tensor))
-      (option))
+      (option)
+      (False))
     (param
       (ident y)
       (variable (ident Tensor))
-      (option))
+      (option)
+      (False))
     (param
       (ident z)
       (variable (ident Tensor))
-      (option)))
+      (option)
+      (False)))
   (option))
  (list
   (assign
diff --git a/test/test_jit.py b/test/test_jit.py
index 76ba1d9..b58a5b9 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -2935,6 +2935,22 @@ class TestScript(JitTestCase):
             # type: (Tensor) -> Tensor
             return fn2(input, [1])

+    def test_parser_kwargonly(self):
+        cu = torch.jit.CompilationUnit('''
+            def foo(x, *, y) -> Tuple[Tensor, Tensor]:
+                return x, x
+            def bar(x):
+                return foo(x, y=x)
+        ''')
+        self.assertTrue('*' in cu.module._get_method('foo').pretty_print_schema())
+        with self.assertRaisesRegex(RuntimeError, "not provided"):
+            torch.jit.CompilationUnit('''
+                def foo(x, *, y) -> Tuple[Tensor, Tensor]:
+                    return x, x
+                def bar(x):
+                    return foo(x, x)
+            ''')
+
     def test_annoying_doubles(self):
         mod = types.ModuleType("temp")
         mod.inf = float("inf")
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 40cfc41..dc7f6b3 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -662,7 +662,7 @@ struct to_ir {
           type,
           N,
           default_value,
-          /*kwarg_only =*/false);
+          decl_arg.kwarg_only());
       retval.push_back(arg);
     }
     return retval;
@@ -2349,7 +2349,8 @@ struct to_ir {
     // aten::slice, we should separate it from this function.
     if (dim) {
       AT_ASSERT(input->type()->isSubtypeOf(TensorType::get()));
-      args.emplace_back(loc, "dim", graph->insertConstant(dim.value(), nullptr, loc));
+      args.emplace_back(
+          loc, "dim", graph->insertConstant(dim.value(), nullptr, loc));
     } else {
       AT_ASSERT(!input->type()->isSubtypeOf(TensorType::get()));
     }
@@ -2366,7 +2367,8 @@ struct to_ir {
       return emitTupleSlice(loc, args[0], args[1], c10::nullopt);
     }
   }
-  NamedValue step = NamedValue(loc, "step", graph->insertConstant(1, nullptr, loc));
+  NamedValue step =
+      NamedValue(loc, "step", graph->insertConstant(1, nullptr, loc));
   return emitBuiltinCall(
       loc, *graph, aten::slice, c10::nullopt, args, {step}, true);
 }
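As a usage sketch (not part of the patch, and assuming a build that contains
it): with kwarg_only flags threaded through the compiler as above, a
derivative formula whose schema contains a `*` marker, such as aten::expand's,
is intended to be matched directly when differentiating a scripted function:

    import torch

    @torch.jit.script
    def g(x):
        # aten::expand's schema carries a keyword-only `implicit` flag; its
        # derivative formula (see symbolic_script.cpp below) no longer needs
        # the "*, "-stripping workaround to be found.
        return x.expand([4, 3]).sum()

    x = torch.randn(1, 3, requires_grad=True)
    g(x).backward()
    print(x.grad.shape)  # torch.Size([1, 3])
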
diff --git a/torch/csrc/jit/script/parser.cpp b/torch/csrc/jit/script/parser.cpp
index e8b6c03..b47a0b6 100644
--- a/torch/csrc/jit/script/parser.cpp
+++ b/torch/csrc/jit/script/parser.cpp
@@ -1,7 +1,7 @@
+#include <torch/csrc/jit/script/parser.h>
 #include <c10/util/Optional.h>
 #include <torch/csrc/jit/script/lexer.h>
 #include <torch/csrc/jit/script/parse_string_literal.h>
-#include <torch/csrc/jit/script/parser.h>
 #include <torch/csrc/jit/script/tree.h>
 #include <torch/csrc/jit/script/tree_views.h>
@@ -243,19 +243,27 @@ struct ParserImpl {
     }
     return Expr(prefix);
   }
-  template <typename T>
-  List<T> parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) {
-    auto r = L.cur().range;
+  void parseSequence(
+      int begin,
+      int sep,
+      int end,
+      const std::function<void()>& parse) {
     if (begin != TK_NOTHING)
       L.expect(begin);
-    std::vector<T> elements;
     if (L.cur().kind != end) {
       do {
-        elements.push_back((this->*parse)());
+        parse();
       } while (L.nextIf(sep));
     }
     if (end != TK_NOTHING)
       L.expect(end);
+  }
+  template <typename T>
+  List<T> parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) {
+    auto r = L.cur().range;
+    std::vector<T> elements;
+    parseSequence(
+        begin, sep, end, [&] { elements.emplace_back((this->*parse)()); });
     return List<T>::create(r, elements);
   }
@@ -326,7 +334,7 @@ struct ParserImpl {
     return Subscript::create(range, Expr(value), subscript_exprs);
   }

-  TreeRef parseParam() {
+  TreeRef parseParam(bool kwarg_only) {
     auto ident = parseIdent();
     TreeRef type;
     if (L.nextIf(':')) {
@@ -341,7 +349,7 @@ struct ParserImpl {
       def = Maybe<Expr>::create(L.cur().range);
     }
     return Param::create(
-        type->range(), Ident(ident), Expr(type), Maybe<Expr>(def));
+        type->range(), Ident(ident), Expr(type), Maybe<Expr>(def), kwarg_only);
   }

   Param parseBareTypeAnnotation() {
@@ -350,7 +358,8 @@ struct ParserImpl {
         type.range(),
         Ident::create(type.range(), ""),
         type,
-        Maybe<Expr>::create(type.range()));
+        Maybe<Expr>::create(type.range()),
+        /*kwarg_only=*/false);
   }

   Decl parseTypeComment() {
@@ -517,9 +526,22 @@ struct ParserImpl {
     }
   }

+  List<Param> parseParams() {
+    auto r = L.cur().range;
+    std::vector<Param> params;
+    bool kwarg_only = false;
+    parseSequence('(', ',', ')', [&] {
+      if (!kwarg_only && L.nextIf('*')) {
+        kwarg_only = true;
+      } else {
+        params.emplace_back(parseParam(kwarg_only));
+      }
+    });
+    return List<Param>::create(r, params);
+  }
   Decl parseDecl() {
-    auto paramlist = parseList('(', ',', ')', &ParserImpl::parseParam);
     // Parse return type annotation
+    List<Param> paramlist = parseParams();
     TreeRef return_type;
     Maybe<Expr> return_annotation = parseReturnAnnotation();
     L.expect(':');
diff --git a/torch/csrc/jit/script/python_tree_views.cpp b/torch/csrc/jit/script/python_tree_views.cpp
index 192c28f..2035956 100644
--- a/torch/csrc/jit/script/python_tree_views.cpp
+++ b/torch/csrc/jit/script/python_tree_views.cpp
@@ -91,9 +91,13 @@ void initTreeViewBindings(PyObject* module) {
       "name", [](const Ident& self) { return self.name(); });
   py::class_<Param>(m, "Param")
-      .def(py::init([](const Expr& type, const Ident& name) {
+      .def(py::init([](const Expr& type, const Ident& name, bool kwarg_only) {
         return Param::create(
-            name.range(), name, type, Maybe<Expr>::create(name.range()));
+            name.range(),
+            name,
+            type,
+            Maybe<Expr>::create(name.range()),
+            kwarg_only);
       }));
   py::class_<Attribute>(m, "Attribute")
       .def(py::init([](const Ident& name, const Expr& value) {
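The parseParams() helper above is a small state machine: a single bare '*'
token flips every subsequent parameter to kwarg-only. A rough Python
rendering of that logic (illustrative only; in the real code a stray second
'*' falls through to parseParam and fails there):

    def parse_params(tokens):
        # Each entry pairs a parameter name with its kwarg_only flag.
        params, kwarg_only = [], False
        for tok in tokens:
            if not kwarg_only and tok == '*':
                kwarg_only = True
            else:
                params.append((tok, kwarg_only))
        return params

    print(parse_params(['x', 'y']))       # [('x', False), ('y', False)]
    print(parse_params(['x', '*', 'y']))  # [('x', False), ('y', True)]
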
diff --git a/torch/csrc/jit/script/tree_views.h b/torch/csrc/jit/script/tree_views.h
index c34aaad..853a32f 100644
--- a/torch/csrc/jit/script/tree_views.h
+++ b/torch/csrc/jit/script/tree_views.h
@@ -328,8 +328,12 @@ struct Param : public TreeView {
       const SourceRange& range,
       const Ident& ident,
       const Expr& type,
-      const Maybe<Expr>& def) {
-    return Param(Compound::create(TK_PARAM, range, {ident, type, def}));
+      const Maybe<Expr>& def,
+      bool kwarg_only) {
+    TreeRef kwarg_only_tree =
+        Compound::create(kwarg_only ? TK_TRUE : TK_FALSE, range, {});
+    return Param(
+        Compound::create(TK_PARAM, range, {ident, type, def, kwarg_only_tree}));
   }
   Ident ident() const {
     return Ident(subtree(0));
@@ -340,8 +344,11 @@ struct Param : public TreeView {
   Maybe<Expr> defaultValue() const {
     return Maybe<Expr>(subtree(2));
   }
+  bool kwarg_only() const {
+    return TK_TRUE == subtree(3)->kind();
+  }
   Param withType(const Expr& typ) const {
-    return Param::create(range(), ident(), typ, defaultValue());
+    return Param::create(range(), ident(), typ, defaultValue(), kwarg_only());
   }
 };
diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp
index 84e8da6..394af9b 100644
--- a/torch/csrc/jit/symbolic_script.cpp
+++ b/torch/csrc/jit/symbolic_script.cpp
@@ -31,6 +31,7 @@ const std::vector<std::string> functions = {
         def expand(self,
                    size: List[int],
+                   *,
                    implicit: bool=False):
             self_size = self.size()
             def backward(grad_output):
@@ -491,15 +492,8 @@ c10::optional<GradientPair> gradientInfoForSchema(
     return cache_it->second;
   } else {
     auto schema_str = canonicalSchemaString(schema);
-    // JIT doesn't support keyword only arguments.
-    // Remove ' *,' in schema before looking up
-    // TODO: #16921 properly support keyword only arguments in JIT.
-    auto n = schema_str.find("*, ");
-    if (n != std::string::npos) {
-      schema_str = schema_str.erase(n, 3);
-    }
-
     auto sym_script_it = schema_to_graphs.find(schema_str);
+
     if (sym_script_it != schema_to_graphs.end()) {
       cached_gradient_pairs.emplace_hint(
           cache_it, &schema, sym_script_it->second);
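Why the gradientInfoForSchema change is now safe, as a sketch: the script
parser can represent `*`, so the symbolic-script graphs are keyed on the full
canonical schema string and matched without first erasing "*, ". The schema
string below is an illustration, not the exact canonicalSchemaString output:

    # Hypothetical rendering of the lookup table; only the shape of the
    # keys matters here.
    schema_to_graphs = {
        'expand(Tensor self, int[] size, *, bool implicit) -> Tensor': '<graph>',
    }

    schema_str = 'expand(Tensor self, int[] size, *, bool implicit) -> Tensor'
    assert schema_str in schema_to_graphs  # no erase(n, 3) workaround needed
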
diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py
index 856f2d1..eb727ba 100644
--- a/torch/jit/frontend.py
+++ b/torch/jit/frontend.py
@@ -183,18 +183,21 @@ def build_def(ctx, py_def, type_line, is_method):

 _vararg_kwarg_err = ("Compiled functions can't take variable number of arguments "
-                     "or keyword-only arguments")
+                     "or use keyword-only arguments with defaults")

 def build_param_list(ctx, py_args):
     if py_args.vararg is not None or py_args.kwarg is not None:
         raise ValueError(_vararg_kwarg_err)
-    if not PY2 and (py_args.kw_defaults or py_args.kwonlyargs):
+    if not PY2 and any(arg is not None for arg in py_args.kw_defaults):
         raise ValueError(_vararg_kwarg_err)
-    return [build_param(ctx, arg) for arg in py_args.args]
+    result = [build_param(ctx, arg, False) for arg in py_args.args]
+    if not PY2:
+        result += [build_param(ctx, arg, True) for arg in py_args.kwonlyargs]
+    return result

-def build_param(ctx, py_arg):
+def build_param(ctx, py_arg, kwarg_only):
     # NB: In Python3 py_arg is a pair of (str arg, expr? annotation)
     #     In Python2 py_arg is a Name (Expr subclass)
     name = py_arg.id if PY2 else py_arg.arg
@@ -203,7 +206,7 @@ def build_param(ctx, py_arg):
         annotation_expr = build_expr(ctx, py_arg.annotation)
     else:
         annotation_expr = Var(Ident(r, 'Tensor'))
-    return Param(annotation_expr, Ident(r, name))
+    return Param(annotation_expr, Ident(r, name), kwarg_only)

 def get_default_args(fn):
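Finally, a sketch of what the frontend change enables from the Python side
(assumes a build with this patch and the corrected kw_defaults check above;
keyword-only arguments with defaults are still rejected with the error
message shown):

    import torch

    # build_param_list now forwards Python 3 keyword-only args
    # (py_args.kwonlyargs) as kwarg_only Params, so this compiles; both
    # parameters default to Tensor in the absence of annotations.
    @torch.jit.script
    def scaled(x, *, y):
        return x * y

    out = scaled(torch.ones(2), y=torch.ones(2) * 3)
    print(out)  # tensor([3., 3.])
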