// Checks if input_ivalue is a subvalue of type.
bool isSubvalueOf(const IValue& ivalue, TypePtr type) {
+ if (auto optional = type->cast<OptionalType>()) {
+ // Unwrap the optional if the ivalue is not none
+ if (ivalue.isNone()) {
+ return true;
+ } else {
+ return isSubvalueOf(ivalue, optional->getElementType());
+ }
+ }
+
if (ivalue.isTuple()) {
const auto& ivalue_elem = ivalue.toTuple()->elements();
auto tuple_type = type->cast<TupleType>();
module->run_method("tuple_op", tuple_generic_list);
}
+
+TEST(TorchScriptTest, TestOptionalArgMatching) {
+ auto module = torch::jit::compile(R"JIT(
+ def optional_tuple_op(a: Optional[Tuple[int, str]]):
+ if a is None:
+ return 0
+ else:
+ return a[0]
+ )JIT");
+
+ auto optional_tuple = torch::jit::Tuple::create({2, std::string("hi")});
+
+ ASSERT_EQ(2, module->run_method("optional_tuple_op", optional_tuple));
+ ASSERT_EQ(0, module->run_method("optional_tuple_op", IValue()));
+
+}