From: David Riazati Date: Fri, 19 Apr 2019 00:06:09 +0000 (-0700) Subject: Allow optionals arguments from C++ (#19311) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~152 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d9052b21764075981b0a3cd779d506595882bd27;p=platform%2Fupstream%2Fpytorch.git Allow optionals arguments from C++ (#19311) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19311 ghimport-source-id: 699f62eb2bbad53ff2045fb2e217eb1402f2cdc5 Reviewed By: eellison Differential Revision: D14983059 Pulled By: driazati fbshipit-source-id: 442f96d6bd2a8ce67807ccad2594b39aae489ca5 --- diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 20f617f..9f7ffc3 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -187,6 +187,15 @@ TypePtr attemptToRecoverType(const IValue& ivalue) { // Checks if input_ivalue is a subvalue of type. bool isSubvalueOf(const IValue& ivalue, TypePtr type) { + if (auto optional = type->cast()) { + // Unwrap the optional if the ivalue is not none + if (ivalue.isNone()) { + return true; + } else { + return isSubvalueOf(ivalue, optional->getElementType()); + } + } + if (ivalue.isTuple()) { const auto& ivalue_elem = ivalue.toTuple()->elements(); auto tuple_type = type->cast(); diff --git a/test/cpp/api/jit.cpp b/test/cpp/api/jit.cpp index ed44f7b..1f60e6d 100644 --- a/test/cpp/api/jit.cpp +++ b/test/cpp/api/jit.cpp @@ -112,3 +112,19 @@ TEST(TorchScriptTest, TestTupleArgMatching) { module->run_method("tuple_op", tuple_generic_list); } + +TEST(TorchScriptTest, TestOptionalArgMatching) { + auto module = torch::jit::compile(R"JIT( + def optional_tuple_op(a: Optional[Tuple[int, str]]): + if a is None: + return 0 + else: + return a[0] + )JIT"); + + auto optional_tuple = torch::jit::Tuple::create({2, std::string("hi")}); + + ASSERT_EQ(2, module->run_method("optional_tuple_op", optional_tuple)); + ASSERT_EQ(0, module->run_method("optional_tuple_op", IValue())); + +}