Support Kwargs in C++ Function/Method calls (#19086)
authorZachary DeVito <zdevito@fb.com>
Sat, 13 Apr 2019 15:28:10 +0000 (08:28 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 13 Apr 2019 15:42:11 +0000 (08:42 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19086
ghimport-source-id: 7790a5cc6e32f6f72e92add0b9f76dfa49ad9859

Reviewed By: jamesr66a

Differential Revision: D14875729

Pulled By: zdevito

fbshipit-source-id: ad1e4542381d9c33722155459e794f1ba4660dbb

aten/src/ATen/core/function_schema.h
aten/src/ATen/core/function_schema_inl.h [new file with mode: 0644]
test/cpp/jit/test.cpp
test/cpp/jit/test_custom_operators.h
torch/csrc/jit/pybind_utils.h
torch/csrc/jit/script/compilation_unit.h
torch/csrc/jit/script/module.h

index 888cfcb..ea9d02d 100644 (file)
@@ -4,6 +4,7 @@
 #include <ATen/core/interned_strings.h>
 #include <ATen/core/ivalue.h>
 #include <ATen/core/alias_info.h>
+#include <unordered_map>
 
 namespace c10 {
 
@@ -125,6 +126,7 @@ private:
   // arguments are not checked by schema
   const bool is_vararg_;
   const bool is_varret_;
+  void checkArg(const IValue& value, const Argument& argument, optional<size_t> pos) const;
 
 public:
   const std::string& name() const {
@@ -173,9 +175,14 @@ public:
         is_vararg(),
         is_varret());
   }
+
   // Check that inputs have the correct types and appends any missing default
   // values.
-  void checkAndNormalizeInputs(std::vector<IValue>& inputs) const;
+  void checkAndNormalizeInputs(
+      std::vector<IValue>& inputs,
+      const std::unordered_map<std::string, IValue>& kwargs) const;
+
+  void findErrorInKwargs(const std::vector<std::string>& kwargs) const;
 };
 
 inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) {
@@ -196,42 +203,7 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
   return out << arg.type()->str() << " " << arg.name() << (arg.default_value() ? "=<default>" : "");
 }
 
-inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
-  // eventually this should look almost identical to python arg parser, but
-  // it is simpler for now to work directly on this schema
-
-  out << schema.name();
-  out << "(";
-
-  bool seen_kwarg_only = false;
-  for(size_t i = 0; i < schema.arguments().size(); ++i) {
-    if (i > 0) out << ", ";
-    if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
-      out << "*, ";
-      seen_kwarg_only = true;
-    }
-    out << schema.arguments()[i];
-  }
-
-  if(schema.is_vararg()) {
-    if(schema.arguments().size() > 0)
-      out << ", ";
-    out << "...";
-  }
-
-  out << ") -> ";
-  if (schema.returns().size() == 1) {
-    out << schema.returns().at(0).type()->str();
-  } else if (schema.returns().size() > 1) {
-    out << "(";
-    for (size_t i = 0; i < schema.returns().size(); ++i) {
-      if (i > 0) out << ", ";
-      out << schema.returns()[i].type()->str();
-    }
-    out << ")";
-  }
-  return out;
-}
+inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema);
 
 inline std::string toString(const FunctionSchema& schema) {
   std::ostringstream str;
@@ -239,46 +211,6 @@ inline std::string toString(const FunctionSchema& schema) {
   return str.str();
 }
 
-inline void FunctionSchema::checkAndNormalizeInputs(std::vector<IValue>& inputs) const {
-  // Do we have more inputs than the schema accepts?
-  AT_CHECK(
-      inputs.size() <= arguments().size(),
-      "Expected at most ",
-      arguments().size(),
-      " argument(s) for operator '",
-      name(),
-      "', but received ",
-      inputs.size(),
-      " argument(s). Declaration: ",
-      *this);
-
-  for (size_t pos = 0; pos < arguments().size(); ++pos) {
-    const auto& argument = arguments()[pos];
-    if (pos < inputs.size()) {
-      if (!isSubvalueOf(inputs[pos], argument.type())) {
-        AT_ERROR(
-            "Expected value of type ",
-            *argument.type(),
-            " for argument '",
-            argument.name(),
-            "' in position ",
-            pos,
-            ", but instead got value of type ",
-            attemptToRecoverType(inputs[pos])->str(),
-            ". Declaration: ",
-            *this);
-      }
-    } else if (argument.default_value()) {
-      inputs.push_back(*argument.default_value());
-    } else {
-      AT_ERROR(
-          name(),
-          "() is missing value for argument '",
-          argument.name(),
-          "'. Declaration: ",
-          *this);
-    }
-  }
-}
-
 } // namespace c10
+
+#include <ATen/core/function_schema_inl.h>
diff --git a/aten/src/ATen/core/function_schema_inl.h b/aten/src/ATen/core/function_schema_inl.h
new file mode 100644 (file)
index 0000000..a4e1588
--- /dev/null
@@ -0,0 +1,145 @@
+#pragma once
+
+// note: windows build doesn't find symbols in operator files unless
+// this is a header file
+
+namespace c10 {
+
+inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
+  // eventually this should look almost identical to python arg parser, but
+  // it is simpler for now to work directly on this schema
+
+  out << schema.name();
+  out << "(";
+
+  bool seen_kwarg_only = false;
+  for(size_t i = 0; i < schema.arguments().size(); ++i) {
+    if (i > 0) out << ", ";
+    if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
+      out << "*, ";
+      seen_kwarg_only = true;
+    }
+    out << schema.arguments()[i];
+  }
+
+  if(schema.is_vararg()) {
+    if(schema.arguments().size() > 0)
+      out << ", ";
+    out << "...";
+  }
+
+  out << ") -> ";
+  if (schema.returns().size() == 1) {
+    out << schema.returns().at(0).type()->str();
+  } else if (schema.returns().size() > 1) {
+    out << "(";
+    for (size_t i = 0; i < schema.returns().size(); ++i) {
+      if (i > 0) out << ", ";
+      out << schema.returns()[i].type()->str();
+    }
+    out << ")";
+  }
+  return out;
+}
+
+inline void FunctionSchema::checkArg(const IValue& value, const Argument& argument, optional<size_t> pos) const {
+  if (!isSubvalueOf(value, argument.type())) {
+    std::string position = pos ? ::c10::str(" in position ", *pos) : "";
+    AT_ERROR(
+        "Expected value of type ",
+        *argument.type(),
+        " for argument '",
+        argument.name(),
+        "'",
+        position,
+        ", but instead got value of type ",
+        attemptToRecoverType(value)->str(),
+        ". Declaration: ",
+        *this);
+  }
+}
+
+inline void FunctionSchema::findErrorInKwargs(const std::vector<std::string>& kwargs) const {
+  // First check if any of the kwargs are unknown, i.e. don't match the name of
+  // any argument in the schema.
+  for (const auto& kwarg : kwargs) {
+    if (!std::count_if(
+            arguments().begin(),
+            arguments().end(),
+            [&kwarg](const Argument& argument) {
+              return argument.name() == kwarg;
+            })) {
+      throw std::runtime_error(c10::str(
+          "Unknown keyword argument '",
+          kwarg,
+          "' for operator '",
+          name(),
+          "'. Schema: ",
+          *this));
+    }
+  }
+  // If there are unconsumed kwargs but none of them were unknown, the first
+  // positional argument present in the kwargs is duplicated.
+  for (const auto& argument : arguments()) {
+    if (std::find(kwargs.begin(), kwargs.end(), argument.name()) != kwargs.end()) {
+      AT_ASSERT(!argument.default_value());
+      throw std::runtime_error(c10::str(
+          "Argument '",
+          argument.name(),
+          "' specified both as positional and ",
+          "keyword argument. Schema: ",
+          *this));
+    }
+  }
+}
+
+inline void FunctionSchema::checkAndNormalizeInputs(
+    std::vector<IValue>& inputs,
+    const std::unordered_map<std::string, IValue>& kwargs) const {
+  // Do we have more inputs than the schema accepts?
+  AT_CHECK(
+      inputs.size() <= arguments().size(),
+      "Expected at most ",
+      arguments().size(),
+      " argument(s) for operator '",
+      name(),
+      "', but received ",
+      inputs.size(),
+      " argument(s). Declaration: ",
+      *this);
+
+  size_t consumed_kwargs = 0;
+  for (size_t pos = 0; pos < arguments().size(); ++pos) {
+    const auto& argument = arguments()[pos];
+    if (pos < inputs.size()) {
+      checkArg(inputs[pos], argument, pos);
+      continue;
+    }
+    auto it = kwargs.find(argument.name());
+    if (it != kwargs.end()) {
+      checkArg(it->second, argument, nullopt);
+      inputs.push_back(it->second);
+      consumed_kwargs++;
+      continue;
+    }
+    if (argument.default_value()) {
+      inputs.push_back(*argument.default_value());
+      continue;
+    }
+    AT_ERROR(
+        name(),
+        "() is missing value for argument '",
+        argument.name(),
+        "'. Declaration: ",
+        *this);
+  }
+  if (consumed_kwargs != kwargs.size()) {
+    std::vector<std::string> names;
+    for(const auto& k : kwargs) {
+      names.emplace_back(k.first);
+    }
+    findErrorInKwargs(names);
+  }
+}
+
+} // namespace c10
index a15cf00..0b9238c 100644 (file)
@@ -41,6 +41,7 @@ namespace jit {
   _(CreateAutodiffSubgraphs)       \
   _(CustomOperators)               \
   _(CustomOperatorAliasing)        \
+  _(IValueKWargs)                  \
   _(Differentiate)                 \
   _(DifferentiateWithRequiresGrad) \
   _(DynamicDAG)                    \
index bc0e071..cd52d90 100644 (file)
@@ -7,6 +7,7 @@
 #include "torch/csrc/jit/irparser.h"
 #include "torch/csrc/jit/passes/alias_analysis.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/jit.h"
 
 namespace torch {
 namespace jit {
@@ -262,6 +263,17 @@ graph(%x: Tensor, %y: Tensor):
     testing::FileCheck().run(text, *graph);
   }
 }
+
+void testIValueKWargs() {
+  const auto text = R"(
+    def foo(a : int, b : int, c : int = 4):
+      return a + 2*b + 3*c
+  )";
+  auto cu = compile(text);
+  auto result = cu->get_function("foo")({1}, {{"b", 3}});
+  ASSERT_EQ(result.toInt(), 19);
+}
+
 } // namespace test
 } // namespace jit
 } // namespace torch
index 1bf2235..88bd8c6 100644 (file)
 
 namespace torch {
 namespace jit {
-namespace detail {
-
-using ::c10::Argument;
-using ::c10::FunctionSchema;
 
 // error reporting: when reporting user-caused errors, these functions should
 // not use AT_ERROR macros, since these macros add stack trace information
 // that is confusing to display to the end user since it always reports
 // locations in libtorch code rather than user code.
 
-inline void findErrorInKwargs(const FunctionSchema& schema, py::kwargs kwargs) {
-  const auto& arguments = schema.arguments();
-  // First check if any of the kwargs are unknown, i.e. don't match the name of
-  // any argument in the schema.
-  for (const auto& kwarg : kwargs) {
-    const auto key = py::cast<std::string>(kwarg.first);
-    if (!std::count_if(
-            arguments.begin(),
-            arguments.end(),
-            [&key](const Argument& argument) {
-              return argument.name() == key;
-            })) {
-      throw std::runtime_error(c10::str(
-          "Unknown keyword argument '",
-          key,
-          "' for operator '",
-          schema.name(),
-          "'. Schema: ",
-          schema));
-    }
-  }
-  // If there are unconsumed kwargs but none of them were unknown, the first
-  // positional argument present in the kwargs is duplicated.
-  for (const auto& argument : arguments) {
-    if (kwargs.contains(argument.name().c_str())) {
-      AT_ASSERT(!argument.default_value());
-      throw std::runtime_error(c10::str(
-          "Argument '",
-          argument.name(),
-          "' specified both as positional and ",
-          "keyword argument. Schema: ",
-          schema));
-    }
-  }
-}
-} // namespace detail
-
 inline IValue toIValue(py::handle input) {
   if (THPVariable_Check(input.ptr())) {
     auto ten = py::cast<at::Tensor>(input);
@@ -347,11 +306,10 @@ inline py::object toPyObject(IValue&& ivalue) {
     const auto obj = ivalue.toObject();
     const auto classType = ClassType::get(obj->name());
     AT_ASSERT(classType);
-    auto pyClass = py::module::import("torch.jit")
-                       .attr("_get_script_class")(obj->name());
+    auto pyClass =
+        py::module::import("torch.jit").attr("_get_script_class")(obj->name());
     auto pyObj = pyClass.attr("__new__")(pyClass);
 
-
     const auto numAttrs = classType->numAttributes();
 
     for (size_t slot = 0; slot < numAttrs; slot++) {
@@ -436,7 +394,11 @@ inline Stack createStackForSchema(
   }
 
   if (consumed_kwargs != kwargs.size()) {
-    detail::findErrorInKwargs(schema, kwargs);
+    std::vector<std::string> names;
+    for (const auto& kwarg : kwargs) {
+      names.emplace_back(py::cast<std::string>(kwarg.first));
+    }
+    schema.findErrorInKwargs(names);
   }
 
   return stack;
index 790d061..345431b 100644 (file)
@@ -27,6 +27,7 @@ namespace script {
 struct Def;
 struct SugaredValue;
 struct Function;
+using Kwargs = std::unordered_map<std::string, IValue>;
 
 using Resolver = std::function<std::shared_ptr<SugaredValue>(
     const std::string& name,
@@ -57,8 +58,10 @@ struct TORCH_API Function {
     run(stack);
   }
 
-  IValue operator()(std::vector<IValue> stack) {
-    getSchema().checkAndNormalizeInputs(stack);
+  IValue operator()(
+      std::vector<IValue> stack,
+      const Kwargs& kwargs = Kwargs()) {
+    getSchema().checkAndNormalizeInputs(stack, kwargs);
     run(stack);
     return stack.front();
   }
@@ -183,7 +186,6 @@ struct TORCH_API Function {
   mutable std::unique_ptr<FunctionSchema> schema_;
 };
 
-
 // A CompilationUnit is a list of named script::Functions
 // with helper methods to iterate the list, or invoke the function.
 // Classes have a CompilationUnit holding the class methods
index ef4fd25..10923f0 100644 (file)
@@ -78,8 +78,10 @@ struct TORCH_API Method {
     run(stack);
   }
 
-  IValue operator()(std::vector<IValue> stack) {
-    getSchema().checkAndNormalizeInputs(stack);
+  IValue operator()(
+      std::vector<IValue> stack,
+      const Kwargs& kwargs = Kwargs()) {
+    getSchema().checkAndNormalizeInputs(stack, kwargs);
     for (auto input : initial_ivalues_) {
       push(stack, input.value());
     }