From 056cfaf3ff3dc3672e9923237b5e3867e0c76040 Mon Sep 17 00:00:00 2001 From: Zachary DeVito Date: Tue, 18 Dec 2018 10:27:26 -0800 Subject: [PATCH] Method returns a single argument (#15289) Summary: This PR changes Method (just Method not all graphs) to always have a single return argument. This is part 1 in a set of changes that will enable us to have better handling if early return statements. The simplification that this change provides greatly reduces the work for the next step. This change makes it so that Method and Python handle multiple returns in the same way: * 0 - None * 1 - * many - Tuple[...] The result is that a lot of special-case handling in compiler.cpp and its bindings can be removed. It also fixes several bugs in return handling, including one where return values were not always checked against their attributed values. Notes: * inferTypeFrom is renamed to be more accurate and discourage use. * This has uncovered some bugs in other components, which are noted in the diff. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15289 Differential Revision: D13481649 Pulled By: zdevito fbshipit-source-id: 0e2242a40bb28cca2d0e8be48bede96195e4858c --- aten/src/ATen/core/jit_type.h | 2 +- aten/src/ATen/core/type.cpp | 12 +- test/expect/TestBatched.test_for.expect | 3 +- test/expect/TestBatched.test_if_else.expect | 41 +++--- .../TestBatched.test_if_else_with_scalar.expect | 40 +++--- test/expect/TestBatched.test_if_noelse.expect | 37 +++-- .../TestBatched.test_if_noelse_with_scalar.expect | 36 ++--- test/expect/TestBatched.test_while.expect | 68 +++++---- .../expect/TestFuser.test_lstm_cuda-forward.expect | 3 +- test/expect/TestFuser.test_lstm_traced_cuda.expect | 3 +- .../TestFuser.test_milstm_cuda-forward.expect | 3 +- test/expect/TestJit.test_cu_escaped_number.expect | 3 +- ...estJit.test_pretty_printer-loop_use_test.expect | 2 +- ...Jit.test_pretty_printer-print_weird_test.expect | 3 +- test/expect/TestJit.test_repeated_output.expect | 3 +- test/expect/TestJit.test_trace_tuple.expect | 3 +- .../expect/TestScript.test_augmented_assign.expect | 3 +- .../expect/TestScript.test_constant_pooling.expect | 3 +- test/expect/TestScript.test_if_supertype.expect | 3 +- .../TestScript.test_mutable_dce_graph_input.expect | 3 +- test/expect/TestScript.test_python_frontend.expect | 3 +- test/expect/TestScript.test_tuple_indexing.expect | 3 +- test/expect/TestScript.test_tuple_slicing.expect | 3 +- test/test_jit.py | 75 +++++----- torch/csrc/jit/init.cpp | 3 + torch/csrc/jit/passes/python_print.cpp | 19 ++- torch/csrc/jit/passes/to_batch.cpp | 46 ++++-- torch/csrc/jit/passes/to_batch.h | 2 +- torch/csrc/jit/pybind_utils.h | 2 +- torch/csrc/jit/python_ir.cpp | 3 +- torch/csrc/jit/python_tracer.cpp | 5 +- torch/csrc/jit/script/compiler.cpp | 159 ++++++++++----------- torch/csrc/jit/script/compiler.h | 2 + torch/csrc/jit/script/init.cpp | 6 +- torch/csrc/jit/script/module.h | 23 ++- torch/csrc/jit/script/parser.cpp | 27 ++-- torch/csrc/jit/script/python_tree_views.cpp | 4 +- torch/csrc/jit/script/tree_views.h | 8 +- torch/csrc/jit/tracer.h | 2 +- torch/jit/frontend.py | 3 +- 40 files changed, 364 insertions(+), 308 deletions(-) diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index c057963..000c6fa 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -916,7 +916,7 @@ template<> inline TypePtr getTypePtr>() { return ListTyp template<> inline TypePtr getTypePtr>() { return ListType::ofFloats(); } template<> inline TypePtr getTypePtr>() { return ListType::ofInts(); } -CAFFE2_API TypePtr inferTypeFrom(const IValue& value); +CAFFE2_API TypePtr incompleteInferTypeFrom(const IValue& value); using TypeEnv = std::unordered_map; struct MatchTypeReturn { diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index e381a71..9c69a49 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -116,7 +116,13 @@ ListTypePtr ListType::ofBools() { return value; } -TypePtr inferTypeFrom(const IValue& value) { +// why incomplete? You cannot completely recover a type from +// an IValue, List[List[int]] and List[List[Tensor]] will both +// become ivalue.isGenericList() and cannot be recovered. +// The only appropriate place to use this is where you know that +// you are only dealing with a subset of objects where you can recover +// the type, like in the tracer. +TypePtr incompleteInferTypeFrom(const IValue& value) { if (value.isTensor()) { return CompleteTensorType::create(value.toTensor()); } else if (value.isDouble()) { @@ -136,11 +142,11 @@ TypePtr inferTypeFrom(const IValue& value) { } else if (value.isDoubleList()) { return ListType::ofFloats(); } else if (value.isTuple()) { - return TupleType::create(fmap(value.toTuple()->elements(), inferTypeFrom)); + return TupleType::create(fmap(value.toTuple()->elements(), incompleteInferTypeFrom)); } else if (value.isDevice()) { return DeviceObjType::get(); } - AT_ASSERTM(false, "Unhandled IValue kind in inferTypeFrom"); + AT_ERROR("Type cannot be accurately recovered from this IValue."); } c10::optional unifyTypes(const TypePtr& t1, const TypePtr& t2) { diff --git a/test/expect/TestBatched.test_for.expect b/test/expect/TestBatched.test_for.expect index 9bf819d..ac0e01b 100644 --- a/test/expect/TestBatched.test_for.expect +++ b/test/expect/TestBatched.test_for.expect @@ -17,5 +17,6 @@ graph(%x.1_data : Tensor %data : Tensor = aten::where(%mask, %data.1, %5_data) -> (%7, %data, %mask, %dims) } - return (%x, %10, %11); + %22 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%x, %10, %11) + return (%22); } diff --git a/test/expect/TestBatched.test_if_else.expect b/test/expect/TestBatched.test_if_else.expect index f4ae4ec..0a1261d 100644 --- a/test/expect/TestBatched.test_if_else.expect +++ b/test/expect/TestBatched.test_if_else.expect @@ -7,33 +7,31 @@ graph(%a.1_data : Tensor %6 : int = prim::Constant[value=1]() %7 : Tensor = aten::gt(%a.1_data, %b_data) %8 : Tensor = aten::mul(%a.1_mask, %b_mask) - %9 : Tensor = aten::__or__(%a.1_dims, %b_dims) - %10 : bool = prim::Bool(%7) - %11 : Long() = prim::NumToTensor(%6) - %alpha.1 : float = prim::Float(%11) + %9 : Long() = prim::NumToTensor(%6) + %alpha.1 : float = prim::Float(%9) %data.1 : Tensor = aten::add(%a.1_data, %b_data, %alpha.1) %mask.1 : Tensor = aten::mul(%a.1_mask, %b_mask) %dims.1 : Tensor = aten::__or__(%a.1_dims, %b_dims) - %16 : Long() = prim::NumToTensor(%6) - %alpha : float = prim::Float(%16) + %14 : Long() = prim::NumToTensor(%6) + %alpha : float = prim::Float(%14) %data : Tensor = aten::sub(%a.1_data, %b_data, %alpha) %mask : Tensor = aten::mul(%a.1_mask, %b_mask) %dims : Tensor = aten::__or__(%a.1_dims, %b_dims) - %21 : bool = prim::Constant[value=1]() - %22 : int = prim::Constant[value=1]() - %23 : Tensor = aten::type_as(%8, %7) - %data.2 : Tensor = aten::mul(%7, %23) - %25 : int = aten::dim(%data.2) - %26 : bool = aten::eq(%25, %22) - %cond_data : Tensor, %cond_mask : Tensor = prim::If(%26) + %19 : bool = prim::Constant[value=1]() + %20 : int = prim::Constant[value=1]() + %21 : Tensor = aten::type_as(%8, %7) + %data.2 : Tensor = aten::mul(%7, %21) + %23 : int = aten::dim(%data.2) + %24 : bool = aten::eq(%23, %20) + %cond_data : Tensor, %cond_mask : Tensor = prim::If(%24) block0() { - %29 : int = aten::dim(%data.1) - %30 : int = aten::sub(%29, %22) - %data.4 : Tensor = prim::Loop(%30, %21, %data.2) - block0(%32 : int, %33 : Tensor) { - %34 : int = aten::dim(%33) - %data.3 : Tensor = aten::unsqueeze(%33, %34) - -> (%21, %data.3) + %27 : int = aten::dim(%data.1) + %28 : int = aten::sub(%27, %20) + %data.4 : Tensor = prim::Loop(%28, %19, %data.2) + block0(%30 : int, %31 : Tensor) { + %32 : int = aten::dim(%31) + %data.3 : Tensor = aten::unsqueeze(%31, %32) + -> (%19, %data.3) } %cond_data.1 : Tensor = aten::expand_as(%data.4, %data.1) %cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask.1) @@ -45,5 +43,6 @@ graph(%a.1_data : Tensor %res_data : Tensor = aten::where(%cond_data, %data.1, %data) %res_mask : Tensor = aten::where(%cond_mask, %mask.1, %mask) %res_dims : Tensor = aten::__or__(%dims.1, %dims) - return (%res_data, %res_mask, %res_dims); + %39 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims) + return (%39); } diff --git a/test/expect/TestBatched.test_if_else_with_scalar.expect b/test/expect/TestBatched.test_if_else_with_scalar.expect index 5bbf309..e1ba887 100644 --- a/test/expect/TestBatched.test_if_else_with_scalar.expect +++ b/test/expect/TestBatched.test_if_else_with_scalar.expect @@ -9,32 +9,31 @@ graph(%a.1_data : Tensor %8 : Float() = prim::NumToTensor(%7) %other : float = prim::Float(%8) %10 : Tensor = aten::gt(%a.1_data, %other) - %11 : bool = prim::Bool(%10) - %12 : Long() = prim::NumToTensor(%6) - %alpha.1 : float = prim::Float(%12) + %11 : Long() = prim::NumToTensor(%6) + %alpha.1 : float = prim::Float(%11) %data.1 : Tensor = aten::add(%a.1_data, %b_data, %alpha.1) %mask.1 : Tensor = aten::mul(%a.1_mask, %b_mask) %dims.1 : Tensor = aten::__or__(%a.1_dims, %b_dims) - %17 : Long() = prim::NumToTensor(%6) - %alpha : float = prim::Float(%17) + %16 : Long() = prim::NumToTensor(%6) + %alpha : float = prim::Float(%16) %data : Tensor = aten::sub(%a.1_data, %b_data, %alpha) %mask : Tensor = aten::mul(%a.1_mask, %b_mask) %dims : Tensor = aten::__or__(%a.1_dims, %b_dims) - %22 : bool = prim::Constant[value=1]() - %23 : int = prim::Constant[value=1]() - %24 : Tensor = aten::type_as(%a.1_mask, %10) - %data.2 : Tensor = aten::mul(%10, %24) - %26 : int = aten::dim(%data.2) - %27 : bool = aten::eq(%26, %23) - %cond_data : Tensor, %cond_mask : Tensor = prim::If(%27) + %21 : bool = prim::Constant[value=1]() + %22 : int = prim::Constant[value=1]() + %23 : Tensor = aten::type_as(%a.1_mask, %10) + %data.2 : Tensor = aten::mul(%10, %23) + %25 : int = aten::dim(%data.2) + %26 : bool = aten::eq(%25, %22) + %cond_data : Tensor, %cond_mask : Tensor = prim::If(%26) block0() { - %30 : int = aten::dim(%data.1) - %31 : int = aten::sub(%30, %23) - %data.4 : Tensor = prim::Loop(%31, %22, %data.2) - block0(%33 : int, %34 : Tensor) { - %35 : int = aten::dim(%34) - %data.3 : Tensor = aten::unsqueeze(%34, %35) - -> (%22, %data.3) + %29 : int = aten::dim(%data.1) + %30 : int = aten::sub(%29, %22) + %data.4 : Tensor = prim::Loop(%30, %21, %data.2) + block0(%32 : int, %33 : Tensor) { + %34 : int = aten::dim(%33) + %data.3 : Tensor = aten::unsqueeze(%33, %34) + -> (%21, %data.3) } %cond_data.1 : Tensor = aten::expand_as(%data.4, %data.1) %cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask.1) @@ -46,5 +45,6 @@ graph(%a.1_data : Tensor %res_data : Tensor = aten::where(%cond_data, %data.1, %data) %res_mask : Tensor = aten::where(%cond_mask, %mask.1, %mask) %res_dims : Tensor = aten::__or__(%dims.1, %dims) - return (%res_data, %res_mask, %res_dims); + %41 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims) + return (%41); } diff --git a/test/expect/TestBatched.test_if_noelse.expect b/test/expect/TestBatched.test_if_noelse.expect index c749a9b..c5eb2ef 100644 --- a/test/expect/TestBatched.test_if_noelse.expect +++ b/test/expect/TestBatched.test_if_noelse.expect @@ -7,28 +7,26 @@ graph(%a.1_data : Tensor %6 : int = prim::Constant[value=1]() %7 : Tensor = aten::gt(%a.1_data, %b_data) %8 : Tensor = aten::mul(%a.1_mask, %b_mask) - %9 : Tensor = aten::__or__(%a.1_dims, %b_dims) - %10 : bool = prim::Bool(%7) - %11 : Long() = prim::NumToTensor(%6) - %alpha : float = prim::Float(%11) + %9 : Long() = prim::NumToTensor(%6) + %alpha : float = prim::Float(%9) %data : Tensor = aten::add(%a.1_data, %b_data, %alpha) %mask : Tensor = aten::mul(%a.1_mask, %b_mask) %dims : Tensor = aten::__or__(%a.1_dims, %b_dims) - %16 : bool = prim::Constant[value=1]() - %17 : int = prim::Constant[value=1]() - %18 : Tensor = aten::type_as(%8, %7) - %data.2 : Tensor = aten::mul(%7, %18) - %20 : int = aten::dim(%data.2) - %21 : bool = aten::eq(%20, %17) - %cond_data : Tensor, %cond_mask : Tensor = prim::If(%21) + %14 : bool = prim::Constant[value=1]() + %15 : int = prim::Constant[value=1]() + %16 : Tensor = aten::type_as(%8, %7) + %data.2 : Tensor = aten::mul(%7, %16) + %18 : int = aten::dim(%data.2) + %19 : bool = aten::eq(%18, %15) + %cond_data : Tensor, %cond_mask : Tensor = prim::If(%19) block0() { - %24 : int = aten::dim(%data) - %25 : int = aten::sub(%24, %17) - %data.4 : Tensor = prim::Loop(%25, %16, %data.2) - block0(%27 : int, %28 : Tensor) { - %29 : int = aten::dim(%28) - %data.3 : Tensor = aten::unsqueeze(%28, %29) - -> (%16, %data.3) + %22 : int = aten::dim(%data) + %23 : int = aten::sub(%22, %15) + %data.4 : Tensor = prim::Loop(%23, %14, %data.2) + block0(%25 : int, %26 : Tensor) { + %27 : int = aten::dim(%26) + %data.3 : Tensor = aten::unsqueeze(%26, %27) + -> (%14, %data.3) } %cond_data.1 : Tensor = aten::expand_as(%data.4, %data) %cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask) @@ -40,5 +38,6 @@ graph(%a.1_data : Tensor %res_data : Tensor = aten::where(%cond_data, %data, %a.1_data) %res_mask : Tensor = aten::where(%cond_mask, %mask, %a.1_mask) %res_dims : Tensor = aten::__or__(%dims, %a.1_dims) - return (%res_data, %res_mask, %res_dims); + %34 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims) + return (%34); } diff --git a/test/expect/TestBatched.test_if_noelse_with_scalar.expect b/test/expect/TestBatched.test_if_noelse_with_scalar.expect index ef04cab..0974909 100644 --- a/test/expect/TestBatched.test_if_noelse_with_scalar.expect +++ b/test/expect/TestBatched.test_if_noelse_with_scalar.expect @@ -9,27 +9,26 @@ graph(%a.1_data : Tensor %8 : Float() = prim::NumToTensor(%7) %other : float = prim::Float(%8) %10 : Tensor = aten::gt(%a.1_data, %other) - %11 : bool = prim::Bool(%10) - %12 : Long() = prim::NumToTensor(%6) - %alpha : float = prim::Float(%12) + %11 : Long() = prim::NumToTensor(%6) + %alpha : float = prim::Float(%11) %data : Tensor = aten::add(%a.1_data, %b_data, %alpha) %mask : Tensor = aten::mul(%a.1_mask, %b_mask) %dims : Tensor = aten::__or__(%a.1_dims, %b_dims) - %17 : bool = prim::Constant[value=1]() - %18 : int = prim::Constant[value=1]() - %19 : Tensor = aten::type_as(%a.1_mask, %10) - %data.2 : Tensor = aten::mul(%10, %19) - %21 : int = aten::dim(%data.2) - %22 : bool = aten::eq(%21, %18) - %cond_data : Tensor, %cond_mask : Tensor = prim::If(%22) + %16 : bool = prim::Constant[value=1]() + %17 : int = prim::Constant[value=1]() + %18 : Tensor = aten::type_as(%a.1_mask, %10) + %data.2 : Tensor = aten::mul(%10, %18) + %20 : int = aten::dim(%data.2) + %21 : bool = aten::eq(%20, %17) + %cond_data : Tensor, %cond_mask : Tensor = prim::If(%21) block0() { - %25 : int = aten::dim(%data) - %26 : int = aten::sub(%25, %18) - %data.4 : Tensor = prim::Loop(%26, %17, %data.2) - block0(%28 : int, %29 : Tensor) { - %30 : int = aten::dim(%29) - %data.3 : Tensor = aten::unsqueeze(%29, %30) - -> (%17, %data.3) + %24 : int = aten::dim(%data) + %25 : int = aten::sub(%24, %17) + %data.4 : Tensor = prim::Loop(%25, %16, %data.2) + block0(%27 : int, %28 : Tensor) { + %29 : int = aten::dim(%28) + %data.3 : Tensor = aten::unsqueeze(%28, %29) + -> (%16, %data.3) } %cond_data.1 : Tensor = aten::expand_as(%data.4, %data) %cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask) @@ -41,5 +40,6 @@ graph(%a.1_data : Tensor %res_data : Tensor = aten::where(%cond_data, %data, %a.1_data) %res_mask : Tensor = aten::where(%cond_mask, %mask, %a.1_mask) %res_dims : Tensor = aten::__or__(%dims, %a.1_dims) - return (%res_data, %res_mask, %res_dims); + %36 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims) + return (%36); } diff --git a/test/expect/TestBatched.test_while.expect b/test/expect/TestBatched.test_while.expect index b8f9d03..4ee7999 100644 --- a/test/expect/TestBatched.test_while.expect +++ b/test/expect/TestBatched.test_while.expect @@ -9,38 +9,35 @@ graph(%a.1_data : Tensor %8 : Tensor = aten::gt(%a.1_data, %b_data) %9 : Tensor = aten::mul(%a.1_mask, %b_mask) %10 : Tensor = aten::__or__(%a.1_dims, %b_dims) - %11 : bool = prim::Bool(%8) - %12 : int = prim::Constant[value=0]() - %13 : Tensor = aten::mul(%8, %9) - %14 : Tensor = aten::sum(%13) - %15 : Tensor = aten::gt(%14, %12) - %16 : bool = prim::Bool(%15) - %17 : Tensor, %18 : Tensor, %19 : Tensor, %a : Tensor, %21 : Tensor, %22 : Tensor = prim::Loop(%7, %16, %8, %9, %10, %a.1_data, %a.1_mask, %a.1_dims) - block0(%loop_num : int, %cond_data.2 : Tensor, %cond_mask.2 : Tensor, %cond_dims : Tensor, %6_data : Tensor, %6_mask : Tensor, %6_dims : Tensor) { - %30 : Long() = prim::NumToTensor(%6) - %alpha : float = prim::Float(%30) + %11 : int = prim::Constant[value=0]() + %12 : Tensor = aten::mul(%8, %9) + %13 : Tensor = aten::sum(%12) + %14 : Tensor = aten::gt(%13, %11) + %15 : bool = prim::Bool(%14) + %16 : Tensor, %17 : Tensor, %a : Tensor, %19 : Tensor, %20 : Tensor = prim::Loop(%7, %15, %8, %9, %a.1_data, %a.1_mask, %a.1_dims) + block0(%loop_num : int, %cond_data.2 : Tensor, %cond_mask.2 : Tensor, %6_data : Tensor, %6_mask : Tensor, %6_dims : Tensor) { + %27 : Long() = prim::NumToTensor(%6) + %alpha : float = prim::Float(%27) %data : Tensor = aten::sub(%6_data, %b_data, %alpha) %mask : Tensor = aten::mul(%6_mask, %b_mask) %dims : Tensor = aten::__or__(%6_dims, %b_dims) - %35 : Tensor = aten::gt(%data, %b_data) - %36 : Tensor = aten::mul(%mask, %b_mask) - %37 : Tensor = aten::__or__(%dims, %b_dims) - %38 : bool = prim::Bool(%35) - %39 : bool = prim::Constant[value=1]() - %40 : int = prim::Constant[value=1]() - %41 : Tensor = aten::type_as(%cond_mask.2, %cond_data.2) - %data.2 : Tensor = aten::mul(%cond_data.2, %41) - %43 : int = aten::dim(%data.2) - %44 : bool = aten::eq(%43, %40) - %cond_data : Tensor, %cond_mask : Tensor = prim::If(%44) + %32 : Tensor = aten::gt(%data, %b_data) + %33 : Tensor = aten::mul(%mask, %b_mask) + %34 : bool = prim::Constant[value=1]() + %35 : int = prim::Constant[value=1]() + %36 : Tensor = aten::type_as(%cond_mask.2, %cond_data.2) + %data.2 : Tensor = aten::mul(%cond_data.2, %36) + %38 : int = aten::dim(%data.2) + %39 : bool = aten::eq(%38, %35) + %cond_data : Tensor, %cond_mask : Tensor = prim::If(%39) block0() { - %47 : int = aten::dim(%data) - %48 : int = aten::sub(%47, %40) - %data.4 : Tensor = prim::Loop(%48, %39, %data.2) - block0(%50 : int, %51 : Tensor) { - %52 : int = aten::dim(%51) - %data.3 : Tensor = aten::unsqueeze(%51, %52) - -> (%39, %data.3) + %42 : int = aten::dim(%data) + %43 : int = aten::sub(%42, %35) + %data.4 : Tensor = prim::Loop(%43, %34, %data.2) + block0(%45 : int, %46 : Tensor) { + %47 : int = aten::dim(%46) + %data.3 : Tensor = aten::unsqueeze(%46, %47) + -> (%34, %data.3) } %cond_data.1 : Tensor = aten::expand_as(%data.4, %data) %cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask) @@ -52,12 +49,13 @@ graph(%a.1_data : Tensor %res_data : Tensor = aten::where(%cond_data, %data, %6_data) %res_mask : Tensor = aten::where(%cond_mask, %mask, %6_mask) %res_dims : Tensor = aten::__or__(%dims, %6_dims) - %59 : int = prim::Constant[value=0]() - %60 : Tensor = aten::mul(%35, %36) - %61 : Tensor = aten::sum(%60) - %62 : Tensor = aten::gt(%61, %59) - %63 : bool = prim::Bool(%62) - -> (%63, %35, %36, %37, %res_data, %res_mask, %res_dims) + %54 : int = prim::Constant[value=0]() + %55 : Tensor = aten::mul(%32, %33) + %56 : Tensor = aten::sum(%55) + %57 : Tensor = aten::gt(%56, %54) + %58 : bool = prim::Bool(%57) + -> (%58, %32, %33, %res_data, %res_mask, %res_dims) } - return (%a, %21, %22); + %59 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%a, %19, %20) + return (%59); } diff --git a/test/expect/TestFuser.test_lstm_cuda-forward.expect b/test/expect/TestFuser.test_lstm_cuda-forward.expect index 2933e1b..55fbd37 100644 --- a/test/expect/TestFuser.test_lstm_cuda-forward.expect +++ b/test/expect/TestFuser.test_lstm_cuda-forward.expect @@ -6,7 +6,8 @@ graph(%x : Float(*, *) %b_ih : Float(*) %b_hh : Float(*)) { %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih) - return (%hy, %cy); + %9 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy) + return (%9); } with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *) %1 : Float(*) diff --git a/test/expect/TestFuser.test_lstm_traced_cuda.expect b/test/expect/TestFuser.test_lstm_traced_cuda.expect index 5bd3f4b..3a02e32 100644 --- a/test/expect/TestFuser.test_lstm_traced_cuda.expect +++ b/test/expect/TestFuser.test_lstm_traced_cuda.expect @@ -13,7 +13,8 @@ graph(%input_1 : Float(*, *) %12 : Tensor[] = aten::broadcast_tensors(%11) %13 : Tensor, %14 : Tensor, %15 : Tensor, %16 : Tensor = prim::ListUnpack(%12) %17 : Float(*, *), %cy : Float(*, *) = prim::FusionGroup_0(%cx, %16, %15, %14, %13) - return (%17, %cy); + %19 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%17, %cy) + return (%19); } with prim::FusionGroup_0 = graph(%0 : Float(*, *) %1 : Tensor diff --git a/test/expect/TestFuser.test_milstm_cuda-forward.expect b/test/expect/TestFuser.test_milstm_cuda-forward.expect index 28c7cd5..dd68b64 100644 --- a/test/expect/TestFuser.test_milstm_cuda-forward.expect +++ b/test/expect/TestFuser.test_milstm_cuda-forward.expect @@ -8,7 +8,8 @@ graph(%x : Float(*, *) %beta_h : Float(*) %bias : Float(*)) { %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %bias, %beta_h, %beta_i, %alpha, %hx, %w_hh, %x, %w_ih) - return (%hy, %cy); + %11 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy) + return (%11); } with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *) %1 : Float(*) diff --git a/test/expect/TestJit.test_cu_escaped_number.expect b/test/expect/TestJit.test_cu_escaped_number.expect index cdd788b..49fc2f6 100644 --- a/test/expect/TestJit.test_cu_escaped_number.expect +++ b/test/expect/TestJit.test_cu_escaped_number.expect @@ -1,3 +1,4 @@ def graph(self, - a: Tensor) -> Tuple[]: + a: Tensor) -> None: print("hi\016") + return None diff --git a/test/expect/TestJit.test_pretty_printer-loop_use_test.expect b/test/expect/TestJit.test_pretty_printer-loop_use_test.expect index 7442d34..b97b33b 100644 --- a/test/expect/TestJit.test_pretty_printer-loop_use_test.expect +++ b/test/expect/TestJit.test_pretty_printer-loop_use_test.expect @@ -7,4 +7,4 @@ def graph(self, while _0: y_2 = torch.add_(y, 1, 1) _0, y, z = bool(torch.lt(y_2, 8)), y_2, x - return x, z + return (x, z) diff --git a/test/expect/TestJit.test_pretty_printer-print_weird_test.expect b/test/expect/TestJit.test_pretty_printer-print_weird_test.expect index f6f2da4..448b981 100644 --- a/test/expect/TestJit.test_pretty_printer-print_weird_test.expect +++ b/test/expect/TestJit.test_pretty_printer-print_weird_test.expect @@ -1,3 +1,4 @@ def graph(self, - y: Tensor) -> Tuple[]: + y: Tensor) -> None: print("hi\016") + return None diff --git a/test/expect/TestJit.test_repeated_output.expect b/test/expect/TestJit.test_repeated_output.expect index b28fa6c..eb5b0e2 100644 --- a/test/expect/TestJit.test_repeated_output.expect +++ b/test/expect/TestJit.test_repeated_output.expect @@ -2,5 +2,6 @@ graph(%a : Tensor %b : Tensor) { %2 : int = prim::Constant[value=1]() %3 : Tensor = aten::add(%a, %b, %2) - return (%3, %3); + %4 : (Tensor, Tensor) = prim::TupleConstruct(%3, %3) + return (%4); } diff --git a/test/expect/TestJit.test_trace_tuple.expect b/test/expect/TestJit.test_trace_tuple.expect index 249e266..a46fc5f 100644 --- a/test/expect/TestJit.test_trace_tuple.expect +++ b/test/expect/TestJit.test_trace_tuple.expect @@ -4,5 +4,6 @@ graph(%x : Double(2, 2) %4 : Double(2, 2) = aten::mul(%x, %3) %5 : Double(2, 2) = aten::mul(%x, %2) %6 : (Double(2, 2), Double(2, 2)) = prim::TupleConstruct(%4, %5) - return (%x, %6); + %7 : (Double(2, 2), (Double(2, 2), Double(2, 2))) = prim::TupleConstruct(%x, %6) + return (%7); } diff --git a/test/expect/TestScript.test_augmented_assign.expect b/test/expect/TestScript.test_augmented_assign.expect index c40d3e5..94aeb9f 100644 --- a/test/expect/TestScript.test_augmented_assign.expect +++ b/test/expect/TestScript.test_augmented_assign.expect @@ -5,5 +5,6 @@ graph(%a.1 : Tensor %a.3 : Tensor = aten::sub_(%a.2, %b, %2) %a.4 : Tensor = aten::div_(%a.3, %b) %a : Tensor = aten::mul_(%a.4, %b) - return (%a, %b); + %7 : (Tensor, Tensor) = prim::TupleConstruct(%a, %b) + return (%7); } diff --git a/test/expect/TestScript.test_constant_pooling.expect b/test/expect/TestScript.test_constant_pooling.expect index 7f7340e..29e11ac 100644 --- a/test/expect/TestScript.test_constant_pooling.expect +++ b/test/expect/TestScript.test_constant_pooling.expect @@ -29,5 +29,6 @@ graph(%cond : Tensor) { = prim::Print(%d, %e, %d, %5, %y.4, %5) -> (%c.1, %y.4) } - return (%a, %3, %c, %5, %y); + %19 : (int, int, int, Tensor, Tensor) = prim::TupleConstruct(%a, %3, %c, %5, %y) + return (%19); } diff --git a/test/expect/TestScript.test_if_supertype.expect b/test/expect/TestScript.test_if_supertype.expect index 836ac75..0d516a9 100644 --- a/test/expect/TestScript.test_if_supertype.expect +++ b/test/expect/TestScript.test_if_supertype.expect @@ -9,5 +9,6 @@ graph(%x.1 : Float(*, *) block1() { -> (%x.1, %x.1, %y.1) } - return (%x, %y, %z); + %7 : (Float(*, *), Tensor, Tensor) = prim::TupleConstruct(%x, %y, %z) + return (%7); } diff --git a/test/expect/TestScript.test_mutable_dce_graph_input.expect b/test/expect/TestScript.test_mutable_dce_graph_input.expect index 054ec40..0ac2218 100644 --- a/test/expect/TestScript.test_mutable_dce_graph_input.expect +++ b/test/expect/TestScript.test_mutable_dce_graph_input.expect @@ -8,5 +8,6 @@ graph(%a.1 : Tensor) { %7 : int[] = prim::ListConstruct(%5, %6) %8 : Tensor = aten::rand(%7, %4, %3, %2) %a : Tensor = aten::add_(%a.1, %8, %1) - return (); + %10 : None = prim::None() + return (%10); } diff --git a/test/expect/TestScript.test_python_frontend.expect b/test/expect/TestScript.test_python_frontend.expect index 0e2a47a..649b714 100644 --- a/test/expect/TestScript.test_python_frontend.expect +++ b/test/expect/TestScript.test_python_frontend.expect @@ -69,5 +69,4 @@ (assert (eq (const 1) (const 1)) (option (string_literal hello))) - (return - (list (variable (ident x)))))) + (return (variable (ident x))))) diff --git a/test/expect/TestScript.test_tuple_indexing.expect b/test/expect/TestScript.test_tuple_indexing.expect index 47d845a..5854e9b 100644 --- a/test/expect/TestScript.test_tuple_indexing.expect +++ b/test/expect/TestScript.test_tuple_indexing.expect @@ -14,5 +14,6 @@ graph(%a : Tensor) { } %8 : int = prim::TupleIndex[index=0](%b) %9 : int = prim::TupleIndex[index=1](%b) - return (%8, %9); + %10 : (int, int) = prim::TupleConstruct(%8, %9) + return (%10); } diff --git a/test/expect/TestScript.test_tuple_slicing.expect b/test/expect/TestScript.test_tuple_slicing.expect index 1628b7b..e1a8749 100644 --- a/test/expect/TestScript.test_tuple_slicing.expect +++ b/test/expect/TestScript.test_tuple_slicing.expect @@ -15,6 +15,5 @@ graph(%a : Tensor) { } %c : (int, int, int, int) = prim::TupleSlice[beg=0, end=4](%b) %e : (int, int) = prim::TupleSlice[beg=1, end=3](%c) - %11 : int, %12 : int = prim::TupleUnpack(%e) - return (%11, %12); + return (%e); } diff --git a/test/test_jit.py b/test/test_jit.py index 8c1acf0..89ba4cc 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -182,8 +182,12 @@ def get_execution_plan(graph_executor_state): def get_grad_executor(plan_state, diff_graph_idx=None): - if diff_graph_idx is None and len(list(plan_state.graph.nodes())) != 1: - raise RuntimeError("Can't get a grad_executor for a non-differentiable graph") + if diff_graph_idx is None: + nodes = list(plan_state.graph.nodes()) + if len(nodes) == 1 or (len(nodes) == 2 and nodes[1].kind() == "prim::TupleConstruct"): + pass + else: + raise RuntimeError("Can't get a grad_executor for a non-differentiable graph") grad_executors = list(plan_state.code.grad_executors()) return grad_executors[diff_graph_idx or 0] @@ -262,7 +266,6 @@ class JitTestCase(TestCase): ppv = "op_version_set = 0\n{}".format(pp) sm = copy_structure_and_params(module) torch._C._jit_import_methods(sm, ppv, constant_table) - pp2, _ = sm._python_print() if pp != pp2: self.assertMultiLineEqual(pp, pp2) @@ -553,6 +556,14 @@ class TestJit(JitTestCase): with self.assertRaisesRegex(pickle.PickleError, "not supported"): torch.save(FooToPickle(), "will_fail") + def test_single_tuple_trace(self): + x = torch.tensor(2.) + + def f2(x): + return (x,) + jit_f2 = torch.jit.trace(f2, x) + assert f2(x) == jit_f2(x) # fails + @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA") def test_restore_device_cuda(self): class MyModule(torch.jit.ScriptModule): @@ -4272,7 +4283,7 @@ a") self.checkScript(one_return, [a], optimize=True) self.checkScript(multiple_returns, [a], optimize=True) - with self.assertRaisesRegex(RuntimeError, "Expected 1 return value"): + with self.assertRaisesRegex(RuntimeError, "but is actually of type None"): @torch.jit.script def no_return_bad_annotation(a): # type: (Tensor) -> Tensor @@ -7549,10 +7560,7 @@ a") self.checkScript(tuple_index, (torch.tensor([1]),), optimize=True) tuple_comp = torch.jit.script(tuple_index) self.assertExpectedGraph(tuple_comp.graph) - self.run_pass('lower_all_tuples', tuple_comp.graph) - m = torch.jit.ScriptModule() - m._create_method_from_graph("forward", tuple_comp.graph) - self.assertEqual(m(torch.tensor(1)), (1, 2)) + self.assertEqual(tuple_comp(torch.tensor(1)), (1, 2)) with self.assertRaisesRegex(RuntimeError, "tuple indices must be integer constants"): @torch.jit.script @@ -7596,21 +7604,19 @@ a") return e self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True) + tuple_graph = torch.jit.script(tuple_slice) + self.assertExpectedGraph(tuple_graph.graph) + self.run_pass('lower_all_tuples', tuple_graph.graph) + self.assertTrue('Tuple' not in str(tuple_graph.graph)) tuple_comp = torch.jit.script(tuple_slice) - self.assertExpectedGraph(tuple_comp.graph) - self.run_pass('lower_all_tuples', tuple_comp.graph) - self.assertTrue('Tuple' not in str(tuple_comp.graph)) - m = torch.jit.ScriptModule() - m._create_method_from_graph("forward", tuple_comp.graph) - self.assertEqual(m(torch.tensor(1)), (2, 3)) + self.assertEqual(tuple_comp(torch.tensor(1)), (2, 3)) @torch.jit.script def test_indexing_end_out_of_bounds(): c = (1, 2) return c[2:10] - # output is None in script and () in python - self.assertEqual(test_indexing_end_out_of_bounds(), None) + self.assertEqual(test_indexing_end_out_of_bounds(), ()) def test_unwrap_optional_builtin(self): def test(x): @@ -7665,9 +7671,7 @@ a") self.assertExpected(sm.__getattr__('forward').pretty_print_schema()) def test_annotated_script_fn_return_mismatch(self): - with self.assertRaisesRegex(RuntimeError, r"Return value at position 0 was annotated as " - r"having type \(Tensor, Tensor\) but is " - r"actually of type Tensor"): + with self.assertRaisesRegex(RuntimeError, "but is actually of type"): @torch.jit.script def return_tup(x): # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor] @@ -9353,6 +9357,14 @@ class TestPytorchExportModes(JitTestCase): export_type=torch.onnx.ExportTypes.DIRECTORY) shutil.rmtree(d) + def test_onnx_multiple_return(self): + @torch.jit.script + def foo(a): + return (a, a) + f = io.BytesIO() + x = torch.ones(3) + torch.onnx._export(foo, (x,), f, example_outputs=(x, x)) + @skipIfRocm @skipIfNoLapack def test_aten_fallback(self): @@ -9571,19 +9583,10 @@ def check_alias_annotation(method_name, args, kwargs): def check_output_types(self, func, ref_outputs, args, kwargs): graph = getattr(func, 'last_graph', None) - if not isinstance(ref_outputs, tuple): - ref_outputs = (ref_outputs,) types = [o.type() for o in graph.outputs()] - self.assertEqual(len(types), len(ref_outputs)) - for i, (t, ref_out) in enumerate(zip(types, ref_outputs)): - if isinstance(ref_out, list): - assert len(ref_out) > 0 - elem = ref_out[0] - assert isinstance(elem, torch.Tensor) - self.assertTrue(t.isSubtypeOf(torch._C.ListType.ofTensors())) - else: - ref_type = torch._C.Type.inferFrom(ref_out) - self.assertTrue(ref_type.isSubtypeOf(t)) + self.assertTrue(len(types) == 1) + t = types[0] + torch._C._jit_assert_is_instance(ref_outputs, t) def check_against_reference(self, func, reference_func, args, kwargs=None, @@ -10289,7 +10292,7 @@ class TestFuser(JitTestCase): expected1, expected2 = f(x, y) self.assertEqual(result1, expected1) self.assertEqual(result2, expected2) - self.assertAllFused(script_f.graph_for(x, y)) + self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'}) # TODO: This test seems dead @unittest.skipIf(not IS_WINDOWS, "Testing Fuse skipped on windows") @@ -10372,7 +10375,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase): graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1) - self.assertGraphSize(graph, 2) + self.assertGraphSize(graph, 3) self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2) def test_merges_without_cycles(self): @@ -10404,7 +10407,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase): graph = self._perform_ad_subgraph_slicing(fn, 2, 2) - self.assertGraphSize(graph, 1) + self.assertGraphSize(graph, 2) self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1) def test_does_not_create_cycles(self): @@ -10434,7 +10437,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase): graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1) - self.assertGraphSize(graph, 2) + self.assertGraphSize(graph, 3) self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1) def test_merges_down(self): @@ -10449,7 +10452,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase): graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1) - self.assertGraphSize(graph, 2) + self.assertGraphSize(graph, 3) self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1) def test_respects_lexical_scoping(self): diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index 77c9d54..f0e35f6 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -356,6 +356,9 @@ void initJITBindings(PyObject *module) { // TODO: this is a fake stub }); + m.def("_jit_assert_is_instance", [](py::object obj, TypePtr type) { + toIValue(obj, type); + }); initPythonIRBindings(module); tracer::initPythonTracerBindings(module); diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index 3930b2d..27e7594 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -562,6 +562,10 @@ struct PythonPrintPass { return; switch (node->kind()) { case prim::Return: + if (enforce_importable_ && node->inputs().size() != 1) { + throw script::ErrorReport(node->getSourceLocation()) + << "Exportable methods must have a single return value. Normal use of ScriptMethods should enforce this."; + } if (node->inputs().size() > 0) { indent(); out << "return "; @@ -813,7 +817,17 @@ struct PythonPrintPass { return out; } - void printDefaultValue(std::ostream& stmt, const IValue& value) { + void printDefaultValue(const TypePtr& typ, std::ostream& stmt, const IValue& value) { + // xxx - many weak script modules store default values for broadcasting lists + // that are not actually the same type as the argument. We can only serialize + // default values that will implicitly convert to their declared return type + // since we do not need to serialize these built-in modules with their defaults, + // we just drop them for now. + if (typ->kind() == ListType::Kind && + (value.isInt() || value.isDouble() || value.isBool())) { + return; + } + stmt << "="; if (value.isTensor() && !value.toTensor().defined()) { // XXX - because undefined tensors are not stored as None, we need special handling. // otherwise they get printed as CONSTANTS.c0 and then cannot be recreated because @@ -856,8 +870,7 @@ struct PythonPrintPass { if (defaults_offset != defaults.end()) { const c10::optional& def = *defaults_offset++; if (def) { - out << "="; - printDefaultValue(out, *def); + printDefaultValue(input->type(), out, *def); } } } diff --git a/torch/csrc/jit/passes/to_batch.cpp b/torch/csrc/jit/passes/to_batch.cpp index 8877806..37992b0 100644 --- a/torch/csrc/jit/passes/to_batch.cpp +++ b/torch/csrc/jit/passes/to_batch.cpp @@ -1,5 +1,6 @@ #include #include +#include namespace torch { namespace jit { @@ -19,6 +20,15 @@ std::shared_ptr ToBatch::getBatchOperator(const std::string& name, int64_ throw std::runtime_error("function " + name + " with " + std::to_string(num_inputs) + " inputs is not supported in batched tensor yet"); } +std::vector inlineUnpackedCallTo(Graph& g, Graph& callee, ArrayRef inputs) { + auto outputs = script::inlineCallTo(g, callee, inputs); + if (callee.outputs().size() == 1 && callee.outputs().at(0)->type()->kind() == TupleType::Kind) { + auto tc = script::createTupleUnpack(outputs.at(0)); + outputs = std::vector(tc.begin(), tc.end()); + } + return outputs; +} + // replace aten operator node with BatchTensor operator graph void ToBatch::visitAten(Node* n, Block* block, Block* res_block){ auto res_graph = res_block->owningGraph(); @@ -44,13 +54,13 @@ void ToBatch::visitAten(Node* n, Block* block, Block* res_block){ } auto batch_graph = getBatchOperator(func_name, new_inputs.size()); - auto outputs = script::inlineCallTo(*res_block->owningGraph(), *batch_graph, new_inputs); + auto outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *batch_graph, new_inputs); // Assume all outputs from inlined operator implementation are in the triple form batched tensor or just a single non-tensor. if (outputs.size() == 1) { // if previous output is scalar, transform new output back to scalar from dynamic TypePtr orig_type = n->outputs()[0]->type(); - if (orig_type != outputs[0]->type()){ + if (!orig_type->isSubtypeOf(outputs[0]->type())) { Symbol op; if (orig_type == IntType::get()) { op = prim::Int; @@ -88,7 +98,7 @@ void ToBatch::visitNumToTensor(Node* n, Block* block, Block* res_block){ auto res_graph = res_block->owningGraph(); auto* r_node = res_graph->createClone(n, rn_fn); res_block->appendNode(r_node); - auto outputs = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("batch_from_scalar_tensor"), r_node->outputs()); + auto outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("batch_from_scalar_tensor"), r_node->outputs()); batch_map[n->output()] = outputs; } @@ -229,7 +239,7 @@ void ToBatch::visitIf(Node* n, Block* block, Block* res_block){ inputs.insert(inputs.end(), if_output.begin(), if_output.end()); auto else_output = batch_map.at(n->blocks()[1]->outputs()[i]); inputs.insert(inputs.end(), else_output.begin(), else_output.end()); - auto outputs = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("where", inputs.size()), inputs); + auto outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("where", inputs.size()), inputs); batch_map[n->outputs()[i]] = outputs; } } @@ -338,7 +348,7 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){ } if(cond_is_tensor){ auto cond = batch_map.at(n->inputs()[1]); - auto cond_any = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond); + auto cond_any = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond); rn_env[n->inputs()[1]] =res_graph->insert(prim::Bool, {cond_any[0]}); } for(size_t i = 2; i < n->inputs().size(); i++){ @@ -400,7 +410,7 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){ for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){ inputs.push_back(loop_block->inputs()[i * EXP_BTENSOR_SIZE + j + EXP_BTENSOR_SIZE + 1]); } - outputs = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("where"), inputs); + outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("where"), inputs); } else{ for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){ @@ -408,7 +418,7 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){ } auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]); inputs.insert(inputs.end(), data.begin(), data.end()); - outputs = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("update"), inputs); + outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("update"), inputs); } batch_map[n->outputs()[i]] = outputs; for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){ @@ -419,7 +429,7 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){ // update loop conditions if(cond_is_tensor){ auto cond = batch_map.at(n->blocks()[0]->outputs()[0]); - auto cond_any = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond); + auto cond_any = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond); auto to_bool_output = res_graph->insert(prim::Bool, {cond_any[0]}); loop_block->insertOutput(0, to_bool_output); for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){ @@ -513,17 +523,33 @@ void ToBatch::toBatch(Block* block, Block* res_block) { } } -std::shared_ptr to_batch_graph(std::shared_ptr& graph){ +std::shared_ptr to_batch_graph(std::shared_ptr graph) { + // lower the tuple before the pass + if (graph->outputs().at(0)->type()->kind() == TupleType::Kind) { + graph = graph->copy(); + auto outs = script::createTupleUnpack(graph->outputs().at(0)); + graph->eraseOutput(0); + for(auto o : outs) + graph->registerOutput(o); + EliminateDeadCode(graph->block()); + } std::shared_ptr res_graph = std::make_shared(); ToBatch to_batch; to_batch.toBatch(graph->block(), res_graph->block()); + // methods should only have a single output, so we pack everything into a tuple + auto tup = res_graph->insertNode(res_graph->createTuple(res_graph->outputs())); + while (res_graph->outputs().size() > 0) + res_graph->eraseOutput(res_graph->outputs().size() - 1); + res_graph->registerOutput(tup->output()); + EliminateDeadCode(res_graph->block()); + return res_graph; } void initRegisterBatchOpsBindings(PyObject* module) { auto m = py::handle(module).cast(); - m.def("to_batch_graph", &to_batch_graph); + m.def("to_batch_graph", to_batch_graph); m.def("register_batch_operator", [](std::string name, std::shared_ptr graph){ ToBatch::batch_operator_table[name].push_back(graph); }); diff --git a/torch/csrc/jit/passes/to_batch.h b/torch/csrc/jit/passes/to_batch.h index 4c2a098..959f265 100644 --- a/torch/csrc/jit/passes/to_batch.h +++ b/torch/csrc/jit/passes/to_batch.h @@ -33,6 +33,6 @@ public: TORCH_API void toBatch(Block* block, Block* res_block); }; -TORCH_API std::shared_ptr to_batch_graph(std::shared_ptr& graph); +TORCH_API std::shared_ptr to_batch_graph(std::shared_ptr graph); TORCH_API void initRegisterBatchOpsBindings(PyObject* module); }} diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h index 2b6c336..5089ada 100644 --- a/torch/csrc/jit/pybind_utils.h +++ b/torch/csrc/jit/pybind_utils.h @@ -408,7 +408,7 @@ inline py::object invokeScriptMethodFromPython( AutoNoGIL no_gil_guard; method.run(stack); } - return createPyObjectForStack(std::move(stack)); + return toPyObject(std::move(stack.back())); } inline py::object invokeOperatorFromPython( diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index c34335a..5a39fe3 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -449,8 +449,7 @@ void initPythonIRBindings(PyObject * module_) { }) .def("isSubtypeOf", [](std::shared_ptr& self, std::shared_ptr other) { return self->isSubtypeOf(other); - }) - .def_static("inferFrom", c10::inferTypeFrom); + }); py::class_>(m, "NumberType") .def_static("get", &NumberType::get); diff --git a/torch/csrc/jit/python_tracer.cpp b/torch/csrc/jit/python_tracer.cpp index a83b645..30210a1 100644 --- a/torch/csrc/jit/python_tracer.cpp +++ b/torch/csrc/jit/python_tracer.cpp @@ -60,10 +60,7 @@ std::shared_ptr createGraphByTracing( AT_ERROR("The traced function didn't return any values! Side-effects are not " "captured in traces, so it would be a no-op."); } - if (!PyTuple_Check(out.ptr())) { - out = py::make_tuple(out); - } - tracer::exit(toStack(out)); + tracer::exit({toIValue(out)}); auto graph = enter_info.first->graph; EliminateDeadCode(graph); LowerSimpleTuples(graph); diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 5b018a5..ad6c186 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -448,13 +448,32 @@ Value* tryConvertToType( const TypePtr& concrete_type, Value* value, bool allow_conversions) { - // Allow homogeneous tuples to be casted implicitly to lists of appropriate - // types - if (convertibleToList(value->type(), unwrapOptional(concrete_type)) && - value->type()->kind() == TypeKind::TupleType) { - auto unpacked = createTupleUnpack(value); - auto elem_type = unwrapOptional(concrete_type)->expect()->getElementType(); - value = graph.insertNode(graph.createList(elem_type, unpacked))->output(); + + if (auto value_tuple = value->type()->cast()) { + // Allow homogeneous tuples to be casted implicitly to lists of appropriate + // types + if (convertibleToList(value->type(), unwrapOptional(concrete_type))) { + auto unpacked = createTupleUnpack(value); + auto elem_type = unwrapOptional(concrete_type)->expect()->getElementType(); + value = graph.insertNode(graph.createList(elem_type, unpacked))->output(); + } + // inductively apply implicit conversions to tuples + if (auto concrete_tuple = concrete_type->cast()) { + if (!value_tuple->isSubtypeOf(concrete_tuple) && + concrete_tuple->elements().size() == value_tuple->elements().size()) { + auto unpacked = createTupleUnpack(value); + std::vector converted; + for (size_t i = 0; i < concrete_tuple->elements().size(); ++i) { + converted.emplace_back(tryConvertToType( + loc, + graph, + concrete_tuple->elements().at(i), + unpacked.at(i), + allow_conversions)); + } + value = graph.insertNode(graph.createTuple(converted))->output(); + } + } } if (value->type()->isSubtypeOf(NoneType::get()) && !concrete_type->isSubtypeOf(NoneType::get())){ @@ -834,9 +853,9 @@ struct to_ir { if (self && def.decl().params().size() == 0) { throw ErrorReport(def.decl().params().range()) << "methods must have a self argument"; } - auto schema = extractSchemaFromDef(def); std::vector arguments = emitFormalArguments(self, schema); + // body auto stmts = def.statements(); auto stmts_begin = stmts.begin(); @@ -847,7 +866,8 @@ struct to_ir { return_stmt = Return(*stmts_end); } emitStatements(stmts_begin, stmts_end); - std::vector returns = emitReturn(return_stmt, schema); + std::vector returns = {emitReturn( + return_stmt ? return_stmt->range() : def.range(), return_stmt, schema)}; method.setSchema({def.name().name(), std::move(arguments), std::move(returns)}); // remove any uses of tuples that we inserted that are not needed @@ -895,7 +915,7 @@ private: Decl::create(r, List::create(r, {}), Maybe::create(r, tuple_type)); auto tuple_expr = TupleLiteral::create(r, List::create(r, default_exprs)); - auto ret = Return::create(r, List::create(r, { tuple_expr })); + auto ret = Return::create(r, tuple_expr); auto def = Def::create( r, Ident::create(r, "defaults"), @@ -903,8 +923,9 @@ private: List::create(r, {ret})); auto m = std::make_shared(); defineMethodsInModule(m, {def}, {resolver}, nullptr); - m->get_method("defaults").run(default_values); - return default_values; + Stack stack; + m->get_method("defaults").run(stack); + return stack.at(0).toTuple()->elements(); } std::vector parseArgsFromDecl(const Decl& decl) { @@ -957,44 +978,29 @@ private: return retval; } - std::vector parseReturnsFromDecl(const Decl& decl) { - JIT_ASSERT(decl.return_type().present()); + std::vector parseReturnFromDecl(const Decl& decl) { + // we represent no annoation on a return type as having no values in the + // schema's return() list + // in emitReturn we take the actual return value to be the value of the return + // statement if no one was provided here + if(!decl.return_type().present()) + return {}; + if (handleBroadcastList(decl.return_type().get())) throw ErrorReport(decl.return_type().range()) << "Broadcastable lists cannot appear as a return type"; auto parsed_type = parseTypeFromExpr(decl.return_type().get()); - if (auto tuple_type = parsed_type->cast()) { - // Flatten a single return type of type Tuple into its constituent types - std::vector retval; - for (auto type_ptr : tuple_type->elements()) { - retval.emplace_back( - "", - type_ptr, - /*N =*/c10::nullopt, - /*default_value =*/c10::nullopt, - /*kwarg_only =*/false); - } - return retval; - } else { - return {Argument( - "", - parsed_type, - /*N =*/c10::nullopt, - /*default_value =*/c10::nullopt, - /*kwarg_only =*/false)}; - } + return {Argument( + "", + parsed_type, + /*N =*/c10::nullopt, + /*default_value =*/c10::nullopt, + /*kwarg_only =*/false)}; } FunctionSchema extractSchemaFromDef(const Def &def) { auto name = def.name().name(); std::vector args = parseArgsFromDecl(def.decl()); - std::vector returns; - bool is_varret; - if (def.decl().return_type().present()) { - returns = parseReturnsFromDecl(def.decl()); - is_varret = false; - } else { - is_varret = true; - } - return FunctionSchema(name, args, returns, false, is_varret); + std::vector returns = parseReturnFromDecl(def.decl()); + return FunctionSchema(name, std::move(args), std::move(returns), false, false); } std::vector emitFormalArguments(const SugaredValuePtr& self, const FunctionSchema& schema) { @@ -1030,50 +1036,27 @@ private: } return arguments; } - std::vector emitReturn(c10::optional return_stmt_, const FunctionSchema& schema) { + + Argument emitReturn(const SourceRange& range, c10::optional return_stmt, const FunctionSchema& schema) { + JIT_ASSERT(schema.returns().size() <= 1); // outputs - std::vector returns; - if (return_stmt_) { - auto return_stmt = *return_stmt_; - auto results = getValues(return_stmt.values(), true); - // a single return value that is a tuple expands in place: - // return a - if (return_stmt.values().size() == 1 && results.size() == 1) { - auto result = results.at(0); - if(result->type()->cast()) { - results = createTupleUnpack(result).vec(); - } - } - if (!schema.is_varret() && schema.returns().size() != results.size()) { - throw ErrorReport(def.range()) << "Number of type annotations for function" - << " return (" << schema.returns().size() << ") does not match" - << " the number of returns from the function (" << results.size() << ")!"; - } - const auto& range = return_stmt.range(); - size_t return_type_idx = 0; - for (auto r : results) { - TypePtr type = DynamicType::get(); - if (!schema.is_varret()) { - type = schema.returns().at(return_type_idx).type(); - r = tryConvertToType(range, *graph, type, r, /*allow_conversions=*/false); - if (!r->type()->isSubtypeOf(type)) { - throw ErrorReport(return_stmt.range()) << "Return value at position " - << return_type_idx << " was annotated as having type " << type->str() - << " but is actually of type " << r->type()->str(); - } - return_type_idx++; - } - graph->registerOutput(r); - returns.emplace_back("", type); - } - } else if (schema.returns().size() > 0) { - // schema has returns but there's no return nodes in graph - throw ErrorReport() << "Expected " << schema.returns().size() - << " return value" - << (schema.returns().size() > 1 ? "s" : "") - << " but found no return statement"; + Value* result = return_stmt ? emitExpr(return_stmt->expr()) + : graph->insertConstant(IValue(), range); + TypePtr result_type = schema.returns().size() > 0 + ? schema.returns().at(0).type() + : result->type(); + + if (return_stmt) { + result = tryConvertToType( + range, *graph, result_type, result, /*allow_conversions=*/true); + } + + if (!result->type()->isSubtypeOf(result_type)) { + throw ErrorReport(range) << "Return value was annotated as having type " << result_type->python_str() + << " but is actually of type " << result->type()->python_str(); } - return returns; + graph->registerOutput(result); + return Argument("", result_type); } void emitStatements(const List& statements) { return emitStatements(statements.begin(), statements.end()); @@ -2720,6 +2703,7 @@ const std::unordered_map &ident_to_type_lut() { // technically this is not a python type but we need it when // parsing serialized methods that use implicit converions to Scalar {"number", NumberType::get()}, + {"None", NoneType::get()}, }; return map; } @@ -2769,12 +2753,15 @@ c10::optional parseBaseTypeName(const Expr& expr) { case TK_VAR: { return Var(expr).name().name(); } + case TK_NONE: { + return "None"; + } case '.': { auto select = Select(expr); const std::string& name = select.selector().name(); if (isTorch(select.value()) && name == "Tensor") return "Tensor"; - } + } break; } return at::nullopt; } diff --git a/torch/csrc/jit/script/compiler.h b/torch/csrc/jit/script/compiler.h index 7abe842..1730a88 100644 --- a/torch/csrc/jit/script/compiler.h +++ b/torch/csrc/jit/script/compiler.h @@ -277,6 +277,8 @@ TORCH_API c10::optional findInputWithName( const std::string& name, at::ArrayRef kwargs); +TORCH_API at::ArrayRef createTupleUnpack(Value* v); + } // namespace script } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 236a727..e0d51bf 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -525,7 +525,7 @@ void initJitScriptBindings(PyObject* module) { const std::string& script, ResolutionCallback rcb, bool has_self) { auto self = has_self ? std::make_shared(m) : nullptr; - return defineMethodsInModule(m, script, pythonResolver(rcb), self); + defineMethodsInModule(m, script, pythonResolver(rcb), self); }) .def("_create_methods", [](std::shared_ptr m, const std::vector& defs, @@ -646,7 +646,7 @@ void initJitScriptBindings(PyObject* module) { if (self.find_method("forward")) { Method & m = self.get_method("forward"); return m.graph_for( - createStackForSchema(m.getSchema(), tuple_slice(std::move(args), 1), std::move(kwargs))); + createStackForSchema(m.getSchema(), tuple_slice(std::move(args), 1), kwargs)); } throw std::runtime_error("Attempted to call graph_for on a Module without a compiled forward()"); }) @@ -708,7 +708,7 @@ void initJitScriptBindings(PyObject* module) { .def("graph_for", [](py::args args, py::kwargs kwargs) { // see: [pybind11 varargs] Method& self = py::cast(args[0]); - return self.graph_for(createStackForSchema(self.getSchema(), tuple_slice(std::move(args), 1), std::move(kwargs))); + return self.graph_for(createStackForSchema(self.getSchema(), tuple_slice(std::move(args), 1), kwargs)); }) .def("debug_disable_autodiff_subgraph_inlining", &Method::debugDisableAutodiffSubgraphInlining) .def("schema", &Method::getSchema) diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index 0b57080..aeed2d8 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -69,9 +69,6 @@ struct Method { IValue operator()(std::vector stack) { checkInputsAgainstSchema(stack); run(stack); - if (stack.size() != 1) { - return Tuple::create(std::move(stack)); - } return stack.front(); } @@ -144,12 +141,18 @@ struct Method { auto type = torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes); retval->inputs()[i]->setType(type); } - JIT_ASSERT(retval->outputs().size() == outputs.size()); + at::ArrayRef output_values = retval->outputs(); + // patch this to still work if we are returning a tuple of multiple values + if (output_values.at(0)->type()->kind() == TupleType::Kind) { + JIT_ASSERT(output_values.at(0)->node()->kind()== prim::TupleConstruct); + output_values = output_values.at(0)->node()->inputs(); + } + JIT_ASSERT(output_values.size() == outputs.size()); for (size_t i=0; i < retval->outputs().size(); ++i) { auto scalar_type = outputs[i].type().scalarType(); auto sizes = outputs[i].sizes(); auto type = torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes); - retval->outputs()[i]->setType(type); + output_values[i]->setType(type); } return retval; } @@ -213,7 +216,10 @@ private: } GraphExecutor& get_executor() { - std::call_once(executor_init, [&]{ + std::call_once(executor_init, [&] { + AT_CHECK( + graph()->outputs().size() == 1, + "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs"); executor = GraphExecutor(graph(), optimize); }); return executor; @@ -231,7 +237,10 @@ private: for (size_t pos = 0; pos < schema.arguments().size(); ++pos) { const auto& argument = schema.arguments()[pos]; if (pos < inputs.size()) { - const TypePtr inputType = inferTypeFrom(inputs[pos]); + // XXX - this fails to handle generic aggregates + // and should be replaced with a function isSubvalueOf(ivalue, type) + // That asks if the specific value is a valid instance of type. + const TypePtr inputType = incompleteInferTypeFrom(inputs[pos]); AT_CHECK(inputType->isSubtypeOf(argument.type()), "Expected value of type ", *argument.type(), " for argument '", argument.name(), diff --git a/torch/csrc/jit/script/parser.cpp b/torch/csrc/jit/script/parser.cpp index e92acec..65201af 100644 --- a/torch/csrc/jit/script/parser.cpp +++ b/torch/csrc/jit/script/parser.cpp @@ -365,9 +365,10 @@ struct ParserImpl { } case TK_RETURN: { auto range = L.next().range; - // XXX: TK_NEWLINE makes it accept an empty list - auto values = parseList(TK_NOTHING, ',', TK_NEWLINE, &ParserImpl::parseExp); - return Return::create(range, values); + Expr value = L.cur().kind != TK_NEWLINE ? parseExpOrExpTuple() + : Expr(c(TK_NONE, range, {})); + L.expect(TK_NEWLINE); + return Return::create(range, value); } case TK_RAISE: { auto range = L.next().range; @@ -461,20 +462,24 @@ struct ParserImpl { } while(!L.nextIf(TK_DEDENT)); return c(TK_LIST, r, std::move(stmts)); } - Decl parseDecl() { - auto paramlist = parseList('(', ',', ')', &ParserImpl::parseParam); - // Parse return type annotation - TreeRef return_type; + + Maybe parseReturnAnnotation() { if (L.nextIf(TK_ARROW)) { // Exactly one expression for return type annotation auto return_type_range = L.cur().range; - return_type = Maybe::create(return_type_range, parseExp()); + return Maybe::create(return_type_range, parseExp()); } else { - // Default to returning single tensor. TODO: better sentinel value? - return_type = Maybe::create(L.cur().range); + return Maybe::create(L.cur().range); } + } + + Decl parseDecl() { + auto paramlist = parseList('(', ',', ')', &ParserImpl::parseParam); + // Parse return type annotation + TreeRef return_type; + Maybe return_annotation = parseReturnAnnotation(); L.expect(':'); - return Decl::create(paramlist.range(), List(paramlist), Maybe(return_type)); + return Decl::create(paramlist.range(), List(paramlist), return_annotation); } TreeRef parseFunction(bool is_method) { diff --git a/torch/csrc/jit/script/python_tree_views.cpp b/torch/csrc/jit/script/python_tree_views.cpp index 8d2c39d..d42fc0e 100644 --- a/torch/csrc/jit/script/python_tree_views.cpp +++ b/torch/csrc/jit/script/python_tree_views.cpp @@ -132,8 +132,8 @@ void initTreeViewBindings(PyObject *module) { return AugAssign::create(r, lhs, kind, rhs); })); py::class_(m, "Return") - .def(py::init([](const SourceRange& range, std::vector values) { - return Return::create(range, wrap_list(range, std::move(values))); + .def(py::init([](const SourceRange& range, Expr* value) { + return Return::create(range, value ? *value : Expr(Compound::create(TK_NONE, range, {}))); })); py::class_(m, "Raise") .def(py::init([](const SourceRange& range, Expr *expr) { diff --git a/torch/csrc/jit/script/tree_views.h b/torch/csrc/jit/script/tree_views.h index 270e1aa..6a8f7b0 100644 --- a/torch/csrc/jit/script/tree_views.h +++ b/torch/csrc/jit/script/tree_views.h @@ -497,11 +497,11 @@ struct Return : public Stmt { explicit Return(const TreeRef& tree) : Stmt(tree) { tree_->match(TK_RETURN); } - List values() const { - return List(subtree(0)); + Expr expr() const { + return Expr(subtree(0)); } - static Return create(const SourceRange& range, const List& values) { - return Return(Compound::create(TK_RETURN, range, {values})); + static Return create(const SourceRange& range, const Expr& value) { + return Return(Compound::create(TK_RETURN, range, {value})); } }; diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h index 691a1d9..15de8c7 100644 --- a/torch/csrc/jit/tracer.h +++ b/torch/csrc/jit/tracer.h @@ -162,7 +162,7 @@ inline std::pair, Stack> enter(Stack inputs) { } }; for (IValue& input : inputs) { - input = add_input(input, inferTypeFrom(input), state->graph->addInput()); + input = add_input(input, incompleteInferTypeFrom(input), state->graph->addInput()); } return std::make_pair(state, inputs); } diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index cd885c5..2c69d82 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -253,8 +253,7 @@ class StmtBuilder(Builder): @staticmethod def build_Return(ctx, stmt): r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("return")) - values = (stmt.value,) if not isinstance(stmt.value, ast.Tuple) else stmt.value.elts - return Return(r, [build_expr(ctx, val) for val in values if val is not None]) + return Return(r, None if stmt.value is None else build_expr(ctx, stmt.value)) @staticmethod def build_Raise(ctx, stmt): -- 2.7.4