Partial support for kwarg_only arguments in script (#17339)
authorZachary DeVito <zdevito@gmail.com>
Thu, 21 Feb 2019 23:24:23 +0000 (15:24 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 21 Feb 2019 23:27:06 +0000 (15:27 -0800)
Summary:
This provides the minimum necessary to allow derivative formulas for things that have a kwarg only specifier in their schema. Support for non-parser frontend default arguments for kwargs is not completed.
Fixes #16921
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17339

Differential Revision: D14160923

Pulled By: zdevito

fbshipit-source-id: 822e964c5a3fe2806509cf24d9f51c6dc01711c3

test/expect/TestScript.test_python_frontend.expect
test/test_jit.py
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/parser.cpp
torch/csrc/jit/script/python_tree_views.cpp
torch/csrc/jit/script/tree_views.h
torch/csrc/jit/symbolic_script.cpp
torch/jit/frontend.py

index 649b714..0f7fac3 100644 (file)
@@ -5,15 +5,18 @@
       (param
         (ident x)
         (variable (ident Tensor))
-        (option))
+        (option)
+        (False))
       (param
         (ident y)
         (variable (ident Tensor))
-        (option))
+        (option)
+        (False))
       (param
         (ident z)
         (variable (ident Tensor))
-        (option)))
+        (option)
+        (False)))
     (option))
   (list
     (assign
index 76ba1d9..b58a5b9 100644 (file)
@@ -2935,6 +2935,22 @@ class TestScript(JitTestCase):
             # type: (Tensor) -> Tensor
             return fn2(input, [1])
 
+    def test_parser_kwargonly(self):
+        cu = torch.jit.CompilationUnit('''
+            def foo(x, *, y) -> Tuple[Tensor, Tensor]:
+                return x, x
+            def bar(x):
+                return foo(x, y=x)
+        ''')
+        self.assertTrue('*' in cu.module._get_method('foo').pretty_print_schema())
+        with self.assertRaisesRegex(RuntimeError, "not provided"):
+            torch.jit.CompilationUnit('''
+                def foo(x, *, y) -> Tuple[Tensor, Tensor]:
+                    return x, x
+                def bar(x):
+                    return foo(x, x)
+            ''')
+
     def test_annoying_doubles(self):
         mod = types.ModuleType("temp")
         mod.inf = float("inf")
index 40cfc41..dc7f6b3 100644 (file)
@@ -662,7 +662,7 @@ struct to_ir {
           type,
           N,
           default_value,
-          /*kwarg_only =*/false);
+          decl_arg.kwarg_only());
       retval.push_back(arg);
     }
     return retval;
@@ -2349,7 +2349,8 @@ struct to_ir {
     // aten::slice, we should separate it from this function.
     if (dim) {
       AT_ASSERT(input->type()->isSubtypeOf(TensorType::get()));
-      args.emplace_back(loc, "dim", graph->insertConstant(dim.value(), nullptr, loc));
+      args.emplace_back(
+          loc, "dim", graph->insertConstant(dim.value(), nullptr, loc));
     } else {
       AT_ASSERT(!input->type()->isSubtypeOf(TensorType::get()));
     }
@@ -2366,7 +2367,8 @@ struct to_ir {
         return emitTupleSlice(loc, args[0], args[1], c10::nullopt);
       }
     }
-    NamedValue step = NamedValue(loc, "step", graph->insertConstant(1, nullptr, loc));
+    NamedValue step =
+        NamedValue(loc, "step", graph->insertConstant(1, nullptr, loc));
     return emitBuiltinCall(
         loc, *graph, aten::slice, c10::nullopt, args, {step}, true);
   }
index e8b6c03..b47a0b6 100644 (file)
@@ -1,7 +1,7 @@
+#include <torch/csrc/jit/script/parser.h>
 #include <c10/util/Optional.h>
 #include <torch/csrc/jit/script/lexer.h>
 #include <torch/csrc/jit/script/parse_string_literal.h>
-#include <torch/csrc/jit/script/parser.h>
 #include <torch/csrc/jit/script/tree.h>
 #include <torch/csrc/jit/script/tree_views.h>
 
@@ -243,19 +243,27 @@ struct ParserImpl {
     }
     return Expr(prefix);
   }
-  template <typename T>
-  List<T> parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) {
-    auto r = L.cur().range;
+  void parseSequence(
+      int begin,
+      int sep,
+      int end,
+      const std::function<void()>& parse) {
     if (begin != TK_NOTHING)
       L.expect(begin);
-    std::vector<T> elements;
     if (L.cur().kind != end) {
       do {
-        elements.push_back((this->*parse)());
+        parse();
       } while (L.nextIf(sep));
     }
     if (end != TK_NOTHING)
       L.expect(end);
+  }
+  template <typename T>
+  List<T> parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) {
+    auto r = L.cur().range;
+    std::vector<T> elements;
+    parseSequence(
+        begin, sep, end, [&] { elements.emplace_back((this->*parse)()); });
     return List<T>::create(r, elements);
   }
 
@@ -326,7 +334,7 @@ struct ParserImpl {
     return Subscript::create(range, Expr(value), subscript_exprs);
   }
 
-  TreeRef parseParam() {
+  TreeRef parseParam(bool kwarg_only) {
     auto ident = parseIdent();
     TreeRef type;
     if (L.nextIf(':')) {
@@ -341,7 +349,7 @@ struct ParserImpl {
       def = Maybe<Expr>::create(L.cur().range);
     }
     return Param::create(
-        type->range(), Ident(ident), Expr(type), Maybe<Expr>(def));
+        type->range(), Ident(ident), Expr(type), Maybe<Expr>(def), kwarg_only);
   }
 
   Param parseBareTypeAnnotation() {
@@ -350,7 +358,8 @@ struct ParserImpl {
         type.range(),
         Ident::create(type.range(), ""),
         type,
-        Maybe<Expr>::create(type.range()));
+        Maybe<Expr>::create(type.range()),
+        /*kwarg_only=*/false);
   }
 
   Decl parseTypeComment() {
@@ -517,9 +526,22 @@ struct ParserImpl {
     }
   }
 
+  List<Param> parseParams() {
+    auto r = L.cur().range;
+    std::vector<Param> params;
+    bool kwarg_only = false;
+    parseSequence('(', ',', ')', [&] {
+      if (!kwarg_only && L.nextIf('*')) {
+        kwarg_only = true;
+      } else {
+        params.emplace_back(parseParam(kwarg_only));
+      }
+    });
+    return List<Param>::create(r, params);
+  }
   Decl parseDecl() {
-    auto paramlist = parseList('(', ',', ')', &ParserImpl::parseParam);
     // Parse return type annotation
+    List<Param> paramlist = parseParams();
     TreeRef return_type;
     Maybe<Expr> return_annotation = parseReturnAnnotation();
     L.expect(':');
index 192c28f..2035956 100644 (file)
@@ -91,9 +91,13 @@ void initTreeViewBindings(PyObject* module) {
           "name", [](const Ident& self) { return self.name(); });
 
   py::class_<Param, TreeView>(m, "Param")
-      .def(py::init([](const Expr& type, const Ident& name) {
+      .def(py::init([](const Expr& type, const Ident& name, bool kwarg_only) {
         return Param::create(
-            name.range(), name, type, Maybe<Expr>::create(name.range()));
+            name.range(),
+            name,
+            type,
+            Maybe<Expr>::create(name.range()),
+            kwarg_only);
       }));
   py::class_<Attribute, TreeView>(m, "Attribute")
       .def(py::init([](const Ident& name, const Expr& value) {
index c34aaad..853a32f 100644 (file)
@@ -328,8 +328,12 @@ struct Param : public TreeView {
       const SourceRange& range,
       const Ident& ident,
       const Expr& type,
-      const Maybe<Expr>& def) {
-    return Param(Compound::create(TK_PARAM, range, {ident, type, def}));
+      const Maybe<Expr>& def,
+      bool kwarg_only) {
+    TreeRef kwarg_only_tree =
+        Compound::create(kwarg_only ? TK_TRUE : TK_FALSE, range, {});
+    return Param(
+        Compound::create(TK_PARAM, range, {ident, type, def, kwarg_only_tree}));
   }
   Ident ident() const {
     return Ident(subtree(0));
@@ -340,8 +344,11 @@ struct Param : public TreeView {
   Maybe<Expr> defaultValue() const {
     return Maybe<Expr>(subtree(2));
   }
+  bool kwarg_only() const {
+    return TK_TRUE == subtree(3)->kind();
+  }
   Param withType(const Expr& typ) const {
-    return Param::create(range(), ident(), typ, defaultValue());
+    return Param::create(range(), ident(), typ, defaultValue(), kwarg_only());
   }
 };
 
index 84e8da6..394af9b 100644 (file)
@@ -31,6 +31,7 @@ const std::vector<std::string> functions = {
 
         def expand(self,
                    size: List[int],
+                   *,
                    implicit: bool=False):
             self_size = self.size()
             def backward(grad_output):
@@ -491,15 +492,8 @@ c10::optional<GradientPair> gradientInfoForSchema(
     return cache_it->second;
   } else {
     auto schema_str = canonicalSchemaString(schema);
-    // JIT doesn't support keyword only arguments.
-    // Remove ' *,' in schema before looking up
-    // TODO: #16921 properly support keyword only arguments in JIT.
-    auto n = schema_str.find("*, ");
-    if (n != std::string::npos) {
-      schema_str = schema_str.erase(n, 3);
-    }
-
     auto sym_script_it = schema_to_graphs.find(schema_str);
+
     if (sym_script_it != schema_to_graphs.end()) {
       cached_gradient_pairs.emplace_hint(
           cache_it, &schema, sym_script_it->second);
index 856f2d1..eb727ba 100644 (file)
@@ -183,18 +183,21 @@ def build_def(ctx, py_def, type_line, is_method):
 
 
 _vararg_kwarg_err = ("Compiled functions can't take variable number of arguments "
-                     "or keyword-only arguments")
+                     "or use keyword-only arguments with defaults")
 
 
 def build_param_list(ctx, py_args):
     if py_args.vararg is not None or py_args.kwarg is not None:
         raise ValueError(_vararg_kwarg_err)
-    if not PY2 and (py_args.kw_defaults or py_args.kwonlyargs):
+    if not PY2 and py_args.kw_defaults:
         raise ValueError(_vararg_kwarg_err)
-    return [build_param(ctx, arg) for arg in py_args.args]
+    result = [build_param(ctx, arg, False) for arg in py_args.args]
+    if not PY2:
+        result += [build_params(ctx, arg, True) for arg in py_args.kwonlyargs]
+    return result
 
 
-def build_param(ctx, py_arg):
+def build_param(ctx, py_arg, kwarg_only):
     # NB: In Python3 py_arg is a pair of (str arg, expr? annotation)
     #     In Python2 py_arg is a Name (Expr subclass)
     name = py_arg.id if PY2 else py_arg.arg
@@ -203,7 +206,7 @@ def build_param(ctx, py_arg):
         annotation_expr = build_expr(ctx, py_arg.annotation)
     else:
         annotation_expr = Var(Ident(r, 'Tensor'))
-    return Param(annotation_expr, Ident(r, name))
+    return Param(annotation_expr, Ident(r, name), kwarg_only)
 
 
 def get_default_args(fn):