Improve IMethod::getArgumentNames to deal with empty argument names list (#62947)
authorJiewen Tan <jwtan@fb.com>
Wed, 11 Aug 2021 23:42:34 +0000 (16:42 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 11 Aug 2021 23:44:00 +0000 (16:44 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62947

This diff improved IMethod::getArgumentNames to deal with empty argument names list.

Test Plan:
buck test mode/dev //caffe2/caffe2/fb/predictor:pytorch_predictor_test -- PyTorchDeployPredictor.GetEmptyArgumentNamesValidationMode
buck test mode/dev //caffe2/caffe2/fb/predictor:pytorch_predictor_test -- PyTorchDeployPredictor.GetEmptyArgumentNamesRealMode

Reviewed By: wconstab

Differential Revision: D30179974

fbshipit-source-id: c7aec35c360a73318867c5b77ebfec3affee47e3

torch/csrc/api/include/torch/imethod.h
torch/csrc/api/src/imethod.cpp
torch/csrc/deploy/deploy.cpp
torch/csrc/jit/python/pybind_utils.h

index db6ca419b682f5e58bca7e1d9fb58a315bfb7e9b..dfabf50ce71913c793ac196f9f43d6ca1fe2a097 100644 (file)
@@ -38,6 +38,7 @@ class IMethod {
   virtual void setArgumentNames(std::vector<std::string>& argumentNames) const = 0;
 
  private:
+  mutable  bool isArgumentNamesInitialized_ { false };
   mutable std::vector<std::string> argumentNames_;
 };
 
index e50101d83ca41b2a17369b238e778c011f8516d9..cfbf1c9e2805dfb6ffc1bc75ccfe5f6e36018af3 100644 (file)
@@ -4,11 +4,11 @@ namespace torch {
 
 const std::vector<std::string>& IMethod::getArgumentNames() const
 {
-  // TODO(jwtan): Deal with empty parameter list.
-  if (!argumentNames_.empty()) {
+  if (isArgumentNamesInitialized_) {
     return argumentNames_;
   }
 
+  isArgumentNamesInitialized_ = true;
   setArgumentNames(argumentNames_);
   return argumentNames_;
 }
index 115af196a23103f5a8186f5cfdd7aed547b4205d..b90b95d6e36c916fb7d08b9ec34ff9eac25a631b 100644 (file)
@@ -78,11 +78,18 @@ InterpreterManager::InterpreterManager(size_t n_interp) : resources_(n_interp) {
   }
 
   // Pre-registered modules.
+  // Since torch::deploy::Obj.toIValue cannot infer empty list, we hack it to
+  // return None for empty list.
   // TODO(jwtan): Make the discovery of these modules easier.
   register_module_source(
       "GetArgumentNamesModule",
       "from inspect import signature\n"
-      "def getArgumentNames(function): return list(signature(function).parameters.keys())\n");
+      "from typing import Callable, Optional\n"
+      "def getArgumentNames(function: Callable) -> Optional[list]:\n"
+      "    names = list(signature(function).parameters.keys())\n"
+      "    if len(names) == 0:\n"
+      "        return None\n"
+      "    return names\n");
   TORCH_DEPLOY_SAFE_CATCH_RETHROW
 }
 
@@ -291,6 +298,10 @@ void PythonMethodWrapper::setArgumentNames(
   auto iArgumentNames =
       session.global("GetArgumentNamesModule", "getArgumentNames")({method})
           .toIValue();
+  if (iArgumentNames.isNone()) {
+    return;
+  }
+
   TORCH_INTERNAL_ASSERT(iArgumentNames.isList());
   auto argumentNames = iArgumentNames.toListRef();
 
index 839a658648fcbc9f2f4b1da838e14aa1a1ab9a66..0138231d3bc3f8cb08381922b21a1a3a84abdb78 100644 (file)
@@ -279,7 +279,6 @@ InferredType tryToInferContainerType(py::handle input);
 
 // Try to infer the type of a Python object
 // The type cannot be inferred if:
-//   input is a None
 //   input is an empty container (list, dict)
 //   input is an list with element types that cannot be unified
 //   input is an dict with key or value types that cannot be unified