Allow optionals arguments from C++ (#19311)
authorDavid Riazati <davidriazati@fb.com>
Fri, 19 Apr 2019 00:06:09 +0000 (17:06 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 19 Apr 2019 00:15:05 +0000 (17:15 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19311
ghimport-source-id: 699f62eb2bbad53ff2045fb2e217eb1402f2cdc5

Reviewed By: eellison

Differential Revision: D14983059

Pulled By: driazati

fbshipit-source-id: 442f96d6bd2a8ce67807ccad2594b39aae489ca5

aten/src/ATen/core/type.cpp
test/cpp/api/jit.cpp

index 20f617f..9f7ffc3 100644 (file)
@@ -187,6 +187,15 @@ TypePtr attemptToRecoverType(const IValue& ivalue) {
 
 // Checks if input_ivalue is a subvalue of type.
 bool isSubvalueOf(const IValue& ivalue, TypePtr type) {
+  if (auto optional = type->cast<OptionalType>()) {
+    // Unwrap the optional if the ivalue is not none
+    if (ivalue.isNone()) {
+      return true;
+    } else {
+      return isSubvalueOf(ivalue, optional->getElementType());
+    }
+  }
+
   if (ivalue.isTuple()) {
     const auto& ivalue_elem = ivalue.toTuple()->elements();
     auto tuple_type = type->cast<TupleType>();
index ed44f7b..1f60e6d 100644 (file)
@@ -112,3 +112,19 @@ TEST(TorchScriptTest, TestTupleArgMatching) {
   module->run_method("tuple_op", tuple_generic_list);
 
 }
+
+TEST(TorchScriptTest, TestOptionalArgMatching) {
+  auto module = torch::jit::compile(R"JIT(
+      def optional_tuple_op(a: Optional[Tuple[int, str]]):
+        if a is None:
+          return 0
+        else:
+          return a[0]
+    )JIT");
+
+  auto optional_tuple = torch::jit::Tuple::create({2, std::string("hi")});
+
+  ASSERT_EQ(2, module->run_method("optional_tuple_op", optional_tuple));
+  ASSERT_EQ(0, module->run_method("optional_tuple_op", IValue()));
+
+}