using namespace ::testing;
using namespace caffe2;
-const char* simple = "torch/csrc/deploy/example/generated/simple";
-const char* simpleJit = "torch/csrc/deploy/example/generated/simple_jit";
-
-// TODO(jwtan): Try unifying cmake and buck for getting the path.
-const char* path(const char* envname, const char* path) {
- const char* env = getenv(envname);
- return env ? env : path;
-}
-
+// TODO(T96218435): Enable the following tests in OSS.
TEST(IMethodTest, CallMethod) {
- auto scriptModel = torch::jit::load(path("SIMPLE_JIT", simpleJit));
- auto scriptMethod = scriptModel.get_method("forward");
+ auto script_model = torch::jit::load(getenv("SIMPLE_JIT"));
+ auto script_method = script_model.get_method("forward");
torch::deploy::InterpreterManager manager(3);
- torch::deploy::Package package = manager.load_package(path("SIMPLE", simple));
- auto pyModel = package.load_pickle("model", "model.pkl");
- torch::deploy::PythonMethodWrapper pyMethod(pyModel, "forward");
+ torch::deploy::Package p = manager.load_package(getenv("SIMPLE"));
+ auto py_model = p.load_pickle("model", "model.pkl");
+ torch::deploy::PythonMethodWrapper py_method(py_model, "forward");
auto input = torch::ones({10, 20});
- auto outputPy = pyMethod({input});
- auto outputScript = scriptMethod({input});
- EXPECT_TRUE(outputPy.isTensor());
- EXPECT_TRUE(outputScript.isTensor());
- auto outputPyTensor = outputPy.toTensor();
- auto outputScriptTensor = outputScript.toTensor();
-
- EXPECT_TRUE(outputPyTensor.equal(outputScriptTensor));
- EXPECT_EQ(outputPyTensor.numel(), 200);
+ auto output_py = py_method({input});
+ auto output_script = script_method({input});
+ EXPECT_TRUE(output_py.isTensor());
+ EXPECT_TRUE(output_script.isTensor());
+ auto output_py_tensor = output_py.toTensor();
+ auto output_script_tensor = output_script.toTensor();
+
+ EXPECT_TRUE(output_py_tensor.equal(output_script_tensor));
+ EXPECT_EQ(output_py_tensor.numel(), 200);
}
TEST(IMethodTest, GetArgumentNames) {
- auto scriptModel = torch::jit::load(path("SIMPLE_JIT", simpleJit));
+ auto scriptModel = torch::jit::load(getenv("SIMPLE_JIT"));
auto scriptMethod = scriptModel.get_method("forward");
auto& scriptNames = scriptMethod.getArgumentNames();
EXPECT_STREQ(scriptNames[0].c_str(), "input");
torch::deploy::InterpreterManager manager(3);
- torch::deploy::Package package = manager.load_package(path("SIMPLE", simple));
+ torch::deploy::Package package = manager.load_package(getenv("SIMPLE"));
auto pyModel = package.load_pickle("model", "model.pkl");
torch::deploy::PythonMethodWrapper pyMethod(pyModel, "forward");