Make name() part of IMethod interface (#63995)
authorWill Constable <whc@fb.com>
Mon, 30 Aug 2021 20:29:51 +0000 (13:29 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 30 Aug 2021 20:31:55 +0000 (13:31 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63995

JIT methods already have name() in their interface, and Py methods have names in their implementation.  I'm adding this for a particular case where someone tried to use name() on a JIT method that we're replacing with an IMethod.

Test Plan: add case to imethod API test

Reviewed By: suo

Differential Revision: D30559401

fbshipit-source-id: 76236721f5cd9a9d9d488ddba12bfdd01d679a2c

test/cpp/api/imethod.cpp
torch/csrc/api/include/torch/imethod.h
torch/csrc/deploy/deploy.h
torch/csrc/jit/api/method.h

index 8673e55..b8c12c6 100644 (file)
@@ -28,6 +28,9 @@ TEST(IMethodTest, CallMethod) {
   auto pyModel = package.load_pickle("model", "model.pkl");
   torch::deploy::PythonMethodWrapper pyMethod(pyModel, "forward");
 
+  EXPECT_EQ(scriptMethod.name(), "forward");
+  EXPECT_EQ(pyMethod.name(), "forward");
+
   auto input = torch::ones({10, 20});
   auto outputPy = pyMethod({input});
   auto outputScript = scriptMethod({input});
index af01078..5ab9b83 100644 (file)
@@ -28,6 +28,8 @@ class TORCH_API IMethod {
       std::vector<c10::IValue> args,
       const IValueMap& kwargs = IValueMap()) const = 0;
 
+  virtual const std::string& name() const = 0;
+
   // Returns an ordered list of argument names, possible in both
   // script and python methods.  This is a more portable dependency
   // than a ScriptMethod FunctionSchema, which has more information
index 2036479..f34e4bc 100644 (file)
@@ -232,6 +232,10 @@ class PythonMethodWrapper : public torch::IMethod {
       std::string method_name)
       : model_(std::move(model)), method_name_(std::move(method_name)) {}
 
+  const std::string& name() const override {
+    return method_name_;
+  }
+
   c10::IValue operator()(
       std::vector<c10::IValue> args,
       const IValueMap& kwargs = IValueMap()) const override {
index bcd44a1..3fcc442 100644 (file)
@@ -46,7 +46,7 @@ struct TORCH_API Method : public torch::IMethod {
     return function_->graph();
   }
 
-  const std::string& name() const {
+  const std::string& name() const override {
     return function_->name();
   }