Method returns a single argument (#15289)
authorZachary DeVito <zdevito@fb.com>
Tue, 18 Dec 2018 18:27:26 +0000 (10:27 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 18 Dec 2018 18:44:09 +0000 (10:44 -0800)
Summary:
This PR changes Method (just Method not all graphs) to always have a single
return argument.

This is part 1 in a set of changes that will enable us to have better handling if early return statements.
The simplification that this change provides greatly reduces the work for the next step.

This change makes it so that Method and Python handle multiple returns in the same way:
* 0 - None
* 1 - <single value>
* many - Tuple[...]

The result is that a lot of special-case handling in compiler.cpp and its
bindings can be removed. It also fixes several bugs in return handling,
including one where return values were not always checked against their
attributed values.

Notes:
* inferTypeFrom is renamed to be more accurate and discourage use.
* This has uncovered some bugs in other components, which are noted in
  the diff.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15289

Differential Revision: D13481649

Pulled By: zdevito

fbshipit-source-id: 0e2242a40bb28cca2d0e8be48bede96195e4858c

40 files changed:
aten/src/ATen/core/jit_type.h
aten/src/ATen/core/type.cpp
test/expect/TestBatched.test_for.expect
test/expect/TestBatched.test_if_else.expect
test/expect/TestBatched.test_if_else_with_scalar.expect
test/expect/TestBatched.test_if_noelse.expect
test/expect/TestBatched.test_if_noelse_with_scalar.expect
test/expect/TestBatched.test_while.expect
test/expect/TestFuser.test_lstm_cuda-forward.expect
test/expect/TestFuser.test_lstm_traced_cuda.expect
test/expect/TestFuser.test_milstm_cuda-forward.expect
test/expect/TestJit.test_cu_escaped_number.expect
test/expect/TestJit.test_pretty_printer-loop_use_test.expect
test/expect/TestJit.test_pretty_printer-print_weird_test.expect
test/expect/TestJit.test_repeated_output.expect
test/expect/TestJit.test_trace_tuple.expect
test/expect/TestScript.test_augmented_assign.expect
test/expect/TestScript.test_constant_pooling.expect
test/expect/TestScript.test_if_supertype.expect
test/expect/TestScript.test_mutable_dce_graph_input.expect
test/expect/TestScript.test_python_frontend.expect
test/expect/TestScript.test_tuple_indexing.expect
test/expect/TestScript.test_tuple_slicing.expect
test/test_jit.py
torch/csrc/jit/init.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/passes/to_batch.cpp
torch/csrc/jit/passes/to_batch.h
torch/csrc/jit/pybind_utils.h
torch/csrc/jit/python_ir.cpp
torch/csrc/jit/python_tracer.cpp
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/compiler.h
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/module.h
torch/csrc/jit/script/parser.cpp
torch/csrc/jit/script/python_tree_views.cpp
torch/csrc/jit/script/tree_views.h
torch/csrc/jit/tracer.h
torch/jit/frontend.py

index c057963..000c6fa 100644 (file)
@@ -916,7 +916,7 @@ template<> inline TypePtr getTypePtr<std::vector<at::Tensor>>() { return ListTyp
 template<> inline TypePtr getTypePtr<std::vector<double>>() { return ListType::ofFloats(); }
 template<> inline TypePtr getTypePtr<std::vector<int64_t>>() { return ListType::ofInts(); }
 
-CAFFE2_API TypePtr inferTypeFrom(const IValue& value);
+CAFFE2_API TypePtr incompleteInferTypeFrom(const IValue& value);
 
 using TypeEnv = std::unordered_map<std::string, TypePtr>;
 struct MatchTypeReturn {
index e381a71..9c69a49 100644 (file)
@@ -116,7 +116,13 @@ ListTypePtr ListType::ofBools() {
   return value;
 }
 
-TypePtr inferTypeFrom(const IValue& value) {
+// why incomplete? You cannot completely recover a type from
+// an IValue, List[List[int]] and List[List[Tensor]] will both
+// become ivalue.isGenericList() and cannot be recovered.
+// The only appropriate place to use this is where you know that
+// you are only dealing with a subset of objects where you can recover
+// the type, like in the tracer.
+TypePtr incompleteInferTypeFrom(const IValue& value) {
   if (value.isTensor()) {
     return CompleteTensorType::create(value.toTensor());
   } else if (value.isDouble()) {
@@ -136,11 +142,11 @@ TypePtr inferTypeFrom(const IValue& value) {
   } else if (value.isDoubleList()) {
     return ListType::ofFloats();
   } else if (value.isTuple()) {
-    return TupleType::create(fmap(value.toTuple()->elements(), inferTypeFrom));
+    return TupleType::create(fmap(value.toTuple()->elements(), incompleteInferTypeFrom));
   } else if (value.isDevice()) {
     return DeviceObjType::get();
   }
-  AT_ASSERTM(false, "Unhandled IValue kind in inferTypeFrom");
+  AT_ERROR("Type cannot be accurately recovered from this IValue.");
 }
 
 c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
index 9bf819d..ac0e01b 100644 (file)
@@ -17,5 +17,6 @@ graph(%x.1_data : Tensor
       %data : Tensor = aten::where(%mask, %data.1, %5_data)
       -> (%7, %data, %mask, %dims)
     }
-  return (%x, %10, %11);
+  %22 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%x, %10, %11)
+  return (%22);
 }
index f4ae4ec..0a1261d 100644 (file)
@@ -7,33 +7,31 @@ graph(%a.1_data : Tensor
   %6 : int = prim::Constant[value=1]()
   %7 : Tensor = aten::gt(%a.1_data, %b_data)
   %8 : Tensor = aten::mul(%a.1_mask, %b_mask)
-  %9 : Tensor = aten::__or__(%a.1_dims, %b_dims)
-  %10 : bool = prim::Bool(%7)
-  %11 : Long() = prim::NumToTensor(%6)
-  %alpha.1 : float = prim::Float(%11)
+  %9 : Long() = prim::NumToTensor(%6)
+  %alpha.1 : float = prim::Float(%9)
   %data.1 : Tensor = aten::add(%a.1_data, %b_data, %alpha.1)
   %mask.1 : Tensor = aten::mul(%a.1_mask, %b_mask)
   %dims.1 : Tensor = aten::__or__(%a.1_dims, %b_dims)
-  %16 : Long() = prim::NumToTensor(%6)
-  %alpha : float = prim::Float(%16)
+  %14 : Long() = prim::NumToTensor(%6)
+  %alpha : float = prim::Float(%14)
   %data : Tensor = aten::sub(%a.1_data, %b_data, %alpha)
   %mask : Tensor = aten::mul(%a.1_mask, %b_mask)
   %dims : Tensor = aten::__or__(%a.1_dims, %b_dims)
-  %21 : bool = prim::Constant[value=1]()
-  %22 : int = prim::Constant[value=1]()
-  %23 : Tensor = aten::type_as(%8, %7)
-  %data.2 : Tensor = aten::mul(%7, %23)
-  %25 : int = aten::dim(%data.2)
-  %26 : bool = aten::eq(%25, %22)
-  %cond_data : Tensor, %cond_mask : Tensor = prim::If(%26)
+  %19 : bool = prim::Constant[value=1]()
+  %20 : int = prim::Constant[value=1]()
+  %21 : Tensor = aten::type_as(%8, %7)
+  %data.2 : Tensor = aten::mul(%7, %21)
+  %23 : int = aten::dim(%data.2)
+  %24 : bool = aten::eq(%23, %20)
+  %cond_data : Tensor, %cond_mask : Tensor = prim::If(%24)
     block0() {
-      %29 : int = aten::dim(%data.1)
-      %30 : int = aten::sub(%29, %22)
-      %data.4 : Tensor = prim::Loop(%30, %21, %data.2)
-        block0(%32 : int, %33 : Tensor) {
-          %34 : int = aten::dim(%33)
-          %data.3 : Tensor = aten::unsqueeze(%33, %34)
-          -> (%21, %data.3)
+      %27 : int = aten::dim(%data.1)
+      %28 : int = aten::sub(%27, %20)
+      %data.4 : Tensor = prim::Loop(%28, %19, %data.2)
+        block0(%30 : int, %31 : Tensor) {
+          %32 : int = aten::dim(%31)
+          %data.3 : Tensor = aten::unsqueeze(%31, %32)
+          -> (%19, %data.3)
         }
       %cond_data.1 : Tensor = aten::expand_as(%data.4, %data.1)
       %cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask.1)
@@ -45,5 +43,6 @@ graph(%a.1_data : Tensor
   %res_data : Tensor = aten::where(%cond_data, %data.1, %data)
   %res_mask : Tensor = aten::where(%cond_mask, %mask.1, %mask)
   %res_dims : Tensor = aten::__or__(%dims.1, %dims)
-  return (%res_data, %res_mask, %res_dims);
+  %39 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims)
+  return (%39);
 }
index 5bbf309..e1ba887 100644 (file)
@@ -9,32 +9,31 @@ graph(%a.1_data : Tensor
   %8 : Float() = prim::NumToTensor(%7)
   %other : float = prim::Float(%8)
   %10 : Tensor = aten::gt(%a.1_data, %other)
-  %11 : bool = prim::Bool(%10)
-  %12 : Long() = prim::NumToTensor(%6)
-  %alpha.1 : float = prim::Float(%12)
+  %11 : Long() = prim::NumToTensor(%6)
+  %alpha.1 : float = prim::Float(%11)
   %data.1 : Tensor = aten::add(%a.1_data, %b_data, %alpha.1)
   %mask.1 : Tensor = aten::mul(%a.1_mask, %b_mask)
   %dims.1 : Tensor = aten::__or__(%a.1_dims, %b_dims)
-  %17 : Long() = prim::NumToTensor(%6)
-  %alpha : float = prim::Float(%17)
+  %16 : Long() = prim::NumToTensor(%6)
+  %alpha : float = prim::Float(%16)
   %data : Tensor = aten::sub(%a.1_data, %b_data, %alpha)
   %mask : Tensor = aten::mul(%a.1_mask, %b_mask)
   %dims : Tensor = aten::__or__(%a.1_dims, %b_dims)
-  %22 : bool = prim::Constant[value=1]()
-  %23 : int = prim::Constant[value=1]()
-  %24 : Tensor = aten::type_as(%a.1_mask, %10)
-  %data.2 : Tensor = aten::mul(%10, %24)
-  %26 : int = aten::dim(%data.2)
-  %27 : bool = aten::eq(%26, %23)
-  %cond_data : Tensor, %cond_mask : Tensor = prim::If(%27)
+  %21 : bool = prim::Constant[value=1]()
+  %22 : int = prim::Constant[value=1]()
+  %23 : Tensor = aten::type_as(%a.1_mask, %10)
+  %data.2 : Tensor = aten::mul(%10, %23)
+  %25 : int = aten::dim(%data.2)
+  %26 : bool = aten::eq(%25, %22)
+  %cond_data : Tensor, %cond_mask : Tensor = prim::If(%26)
     block0() {
-      %30 : int = aten::dim(%data.1)
-      %31 : int = aten::sub(%30, %23)
-      %data.4 : Tensor = prim::Loop(%31, %22, %data.2)
-        block0(%33 : int, %34 : Tensor) {
-          %35 : int = aten::dim(%34)
-          %data.3 : Tensor = aten::unsqueeze(%34, %35)
-          -> (%22, %data.3)
+      %29 : int = aten::dim(%data.1)
+      %30 : int = aten::sub(%29, %22)
+      %data.4 : Tensor = prim::Loop(%30, %21, %data.2)
+        block0(%32 : int, %33 : Tensor) {
+          %34 : int = aten::dim(%33)
+          %data.3 : Tensor = aten::unsqueeze(%33, %34)
+          -> (%21, %data.3)
         }
       %cond_data.1 : Tensor = aten::expand_as(%data.4, %data.1)
       %cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask.1)
@@ -46,5 +45,6 @@ graph(%a.1_data : Tensor
   %res_data : Tensor = aten::where(%cond_data, %data.1, %data)
   %res_mask : Tensor = aten::where(%cond_mask, %mask.1, %mask)
   %res_dims : Tensor = aten::__or__(%dims.1, %dims)
-  return (%res_data, %res_mask, %res_dims);
+  %41 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims)
+  return (%41);
 }
index c749a9b..c5eb2ef 100644 (file)
@@ -7,28 +7,26 @@ graph(%a.1_data : Tensor
   %6 : int = prim::Constant[value=1]()
   %7 : Tensor = aten::gt(%a.1_data, %b_data)
   %8 : Tensor = aten::mul(%a.1_mask, %b_mask)
-  %9 : Tensor = aten::__or__(%a.1_dims, %b_dims)
-  %10 : bool = prim::Bool(%7)
-  %11 : Long() = prim::NumToTensor(%6)
-  %alpha : float = prim::Float(%11)
+  %9 : Long() = prim::NumToTensor(%6)
+  %alpha : float = prim::Float(%9)
   %data : Tensor = aten::add(%a.1_data, %b_data, %alpha)
   %mask : Tensor = aten::mul(%a.1_mask, %b_mask)
   %dims : Tensor = aten::__or__(%a.1_dims, %b_dims)
-  %16 : bool = prim::Constant[value=1]()
-  %17 : int = prim::Constant[value=1]()
-  %18 : Tensor = aten::type_as(%8, %7)
-  %data.2 : Tensor = aten::mul(%7, %18)
-  %20 : int = aten::dim(%data.2)
-  %21 : bool = aten::eq(%20, %17)
-  %cond_data : Tensor, %cond_mask : Tensor = prim::If(%21)
+  %14 : bool = prim::Constant[value=1]()
+  %15 : int = prim::Constant[value=1]()
+  %16 : Tensor = aten::type_as(%8, %7)
+  %data.2 : Tensor = aten::mul(%7, %16)
+  %18 : int = aten::dim(%data.2)
+  %19 : bool = aten::eq(%18, %15)
+  %cond_data : Tensor, %cond_mask : Tensor = prim::If(%19)
     block0() {
-      %24 : int = aten::dim(%data)
-      %25 : int = aten::sub(%24, %17)
-      %data.4 : Tensor = prim::Loop(%25, %16, %data.2)
-        block0(%27 : int, %28 : Tensor) {
-          %29 : int = aten::dim(%28)
-          %data.3 : Tensor = aten::unsqueeze(%28, %29)
-          -> (%16, %data.3)
+      %22 : int = aten::dim(%data)
+      %23 : int = aten::sub(%22, %15)
+      %data.4 : Tensor = prim::Loop(%23, %14, %data.2)
+        block0(%25 : int, %26 : Tensor) {
+          %27 : int = aten::dim(%26)
+          %data.3 : Tensor = aten::unsqueeze(%26, %27)
+          -> (%14, %data.3)
         }
       %cond_data.1 : Tensor = aten::expand_as(%data.4, %data)
       %cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask)
@@ -40,5 +38,6 @@ graph(%a.1_data : Tensor
   %res_data : Tensor = aten::where(%cond_data, %data, %a.1_data)
   %res_mask : Tensor = aten::where(%cond_mask, %mask, %a.1_mask)
   %res_dims : Tensor = aten::__or__(%dims, %a.1_dims)
-  return (%res_data, %res_mask, %res_dims);
+  %34 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims)
+  return (%34);
 }
index ef04cab..0974909 100644 (file)
@@ -9,27 +9,26 @@ graph(%a.1_data : Tensor
   %8 : Float() = prim::NumToTensor(%7)
   %other : float = prim::Float(%8)
   %10 : Tensor = aten::gt(%a.1_data, %other)
-  %11 : bool = prim::Bool(%10)
-  %12 : Long() = prim::NumToTensor(%6)
-  %alpha : float = prim::Float(%12)
+  %11 : Long() = prim::NumToTensor(%6)
+  %alpha : float = prim::Float(%11)
   %data : Tensor = aten::add(%a.1_data, %b_data, %alpha)
   %mask : Tensor = aten::mul(%a.1_mask, %b_mask)
   %dims : Tensor = aten::__or__(%a.1_dims, %b_dims)
-  %17 : bool = prim::Constant[value=1]()
-  %18 : int = prim::Constant[value=1]()
-  %19 : Tensor = aten::type_as(%a.1_mask, %10)
-  %data.2 : Tensor = aten::mul(%10, %19)
-  %21 : int = aten::dim(%data.2)
-  %22 : bool = aten::eq(%21, %18)
-  %cond_data : Tensor, %cond_mask : Tensor = prim::If(%22)
+  %16 : bool = prim::Constant[value=1]()
+  %17 : int = prim::Constant[value=1]()
+  %18 : Tensor = aten::type_as(%a.1_mask, %10)
+  %data.2 : Tensor = aten::mul(%10, %18)
+  %20 : int = aten::dim(%data.2)
+  %21 : bool = aten::eq(%20, %17)
+  %cond_data : Tensor, %cond_mask : Tensor = prim::If(%21)
     block0() {
-      %25 : int = aten::dim(%data)
-      %26 : int = aten::sub(%25, %18)
-      %data.4 : Tensor = prim::Loop(%26, %17, %data.2)
-        block0(%28 : int, %29 : Tensor) {
-          %30 : int = aten::dim(%29)
-          %data.3 : Tensor = aten::unsqueeze(%29, %30)
-          -> (%17, %data.3)
+      %24 : int = aten::dim(%data)
+      %25 : int = aten::sub(%24, %17)
+      %data.4 : Tensor = prim::Loop(%25, %16, %data.2)
+        block0(%27 : int, %28 : Tensor) {
+          %29 : int = aten::dim(%28)
+          %data.3 : Tensor = aten::unsqueeze(%28, %29)
+          -> (%16, %data.3)
         }
       %cond_data.1 : Tensor = aten::expand_as(%data.4, %data)
       %cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask)
@@ -41,5 +40,6 @@ graph(%a.1_data : Tensor
   %res_data : Tensor = aten::where(%cond_data, %data, %a.1_data)
   %res_mask : Tensor = aten::where(%cond_mask, %mask, %a.1_mask)
   %res_dims : Tensor = aten::__or__(%dims, %a.1_dims)
-  return (%res_data, %res_mask, %res_dims);
+  %36 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims)
+  return (%36);
 }
index b8f9d03..4ee7999 100644 (file)
@@ -9,38 +9,35 @@ graph(%a.1_data : Tensor
   %8 : Tensor = aten::gt(%a.1_data, %b_data)
   %9 : Tensor = aten::mul(%a.1_mask, %b_mask)
   %10 : Tensor = aten::__or__(%a.1_dims, %b_dims)
-  %11 : bool = prim::Bool(%8)
-  %12 : int = prim::Constant[value=0]()
-  %13 : Tensor = aten::mul(%8, %9)
-  %14 : Tensor = aten::sum(%13)
-  %15 : Tensor = aten::gt(%14, %12)
-  %16 : bool = prim::Bool(%15)
-  %17 : Tensor, %18 : Tensor, %19 : Tensor, %a : Tensor, %21 : Tensor, %22 : Tensor = prim::Loop(%7, %16, %8, %9, %10, %a.1_data, %a.1_mask, %a.1_dims)
-    block0(%loop_num : int, %cond_data.2 : Tensor, %cond_mask.2 : Tensor, %cond_dims : Tensor, %6_data : Tensor, %6_mask : Tensor, %6_dims : Tensor) {
-      %30 : Long() = prim::NumToTensor(%6)
-      %alpha : float = prim::Float(%30)
+  %11 : int = prim::Constant[value=0]()
+  %12 : Tensor = aten::mul(%8, %9)
+  %13 : Tensor = aten::sum(%12)
+  %14 : Tensor = aten::gt(%13, %11)
+  %15 : bool = prim::Bool(%14)
+  %16 : Tensor, %17 : Tensor, %a : Tensor, %19 : Tensor, %20 : Tensor = prim::Loop(%7, %15, %8, %9, %a.1_data, %a.1_mask, %a.1_dims)
+    block0(%loop_num : int, %cond_data.2 : Tensor, %cond_mask.2 : Tensor, %6_data : Tensor, %6_mask : Tensor, %6_dims : Tensor) {
+      %27 : Long() = prim::NumToTensor(%6)
+      %alpha : float = prim::Float(%27)
       %data : Tensor = aten::sub(%6_data, %b_data, %alpha)
       %mask : Tensor = aten::mul(%6_mask, %b_mask)
       %dims : Tensor = aten::__or__(%6_dims, %b_dims)
-      %35 : Tensor = aten::gt(%data, %b_data)
-      %36 : Tensor = aten::mul(%mask, %b_mask)
-      %37 : Tensor = aten::__or__(%dims, %b_dims)
-      %38 : bool = prim::Bool(%35)
-      %39 : bool = prim::Constant[value=1]()
-      %40 : int = prim::Constant[value=1]()
-      %41 : Tensor = aten::type_as(%cond_mask.2, %cond_data.2)
-      %data.2 : Tensor = aten::mul(%cond_data.2, %41)
-      %43 : int = aten::dim(%data.2)
-      %44 : bool = aten::eq(%43, %40)
-      %cond_data : Tensor, %cond_mask : Tensor = prim::If(%44)
+      %32 : Tensor = aten::gt(%data, %b_data)
+      %33 : Tensor = aten::mul(%mask, %b_mask)
+      %34 : bool = prim::Constant[value=1]()
+      %35 : int = prim::Constant[value=1]()
+      %36 : Tensor = aten::type_as(%cond_mask.2, %cond_data.2)
+      %data.2 : Tensor = aten::mul(%cond_data.2, %36)
+      %38 : int = aten::dim(%data.2)
+      %39 : bool = aten::eq(%38, %35)
+      %cond_data : Tensor, %cond_mask : Tensor = prim::If(%39)
         block0() {
-          %47 : int = aten::dim(%data)
-          %48 : int = aten::sub(%47, %40)
-          %data.4 : Tensor = prim::Loop(%48, %39, %data.2)
-            block0(%50 : int, %51 : Tensor) {
-              %52 : int = aten::dim(%51)
-              %data.3 : Tensor = aten::unsqueeze(%51, %52)
-              -> (%39, %data.3)
+          %42 : int = aten::dim(%data)
+          %43 : int = aten::sub(%42, %35)
+          %data.4 : Tensor = prim::Loop(%43, %34, %data.2)
+            block0(%45 : int, %46 : Tensor) {
+              %47 : int = aten::dim(%46)
+              %data.3 : Tensor = aten::unsqueeze(%46, %47)
+              -> (%34, %data.3)
             }
           %cond_data.1 : Tensor = aten::expand_as(%data.4, %data)
           %cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask)
@@ -52,12 +49,13 @@ graph(%a.1_data : Tensor
       %res_data : Tensor = aten::where(%cond_data, %data, %6_data)
       %res_mask : Tensor = aten::where(%cond_mask, %mask, %6_mask)
       %res_dims : Tensor = aten::__or__(%dims, %6_dims)
-      %59 : int = prim::Constant[value=0]()
-      %60 : Tensor = aten::mul(%35, %36)
-      %61 : Tensor = aten::sum(%60)
-      %62 : Tensor = aten::gt(%61, %59)
-      %63 : bool = prim::Bool(%62)
-      -> (%63, %35, %36, %37, %res_data, %res_mask, %res_dims)
+      %54 : int = prim::Constant[value=0]()
+      %55 : Tensor = aten::mul(%32, %33)
+      %56 : Tensor = aten::sum(%55)
+      %57 : Tensor = aten::gt(%56, %54)
+      %58 : bool = prim::Bool(%57)
+      -> (%58, %32, %33, %res_data, %res_mask, %res_dims)
     }
-  return (%a, %21, %22);
+  %59 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%a, %19, %20)
+  return (%59);
 }
index 2933e1b..55fbd37 100644 (file)
@@ -6,7 +6,8 @@ graph(%x : Float(*, *)
       %b_ih : Float(*)
       %b_hh : Float(*)) {
   %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih)
-  return (%hy, %cy);
+  %9 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
+  return (%9);
 }
 with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
       %1 : Float(*)
index 5bd3f4b..3a02e32 100644 (file)
@@ -13,7 +13,8 @@ graph(%input_1 : Float(*, *)
   %12 : Tensor[] = aten::broadcast_tensors(%11)
   %13 : Tensor, %14 : Tensor, %15 : Tensor, %16 : Tensor = prim::ListUnpack(%12)
   %17 : Float(*, *), %cy : Float(*, *) = prim::FusionGroup_0(%cx, %16, %15, %14, %13)
-  return (%17, %cy);
+  %19 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%17, %cy)
+  return (%19);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Tensor
index 28c7cd5..dd68b64 100644 (file)
@@ -8,7 +8,8 @@ graph(%x : Float(*, *)
       %beta_h : Float(*)
       %bias : Float(*)) {
   %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %bias, %beta_h, %beta_i, %alpha, %hx, %w_hh, %x, %w_ih)
-  return (%hy, %cy);
+  %11 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
+  return (%11);
 }
 with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
       %1 : Float(*)
index cdd788b..49fc2f6 100644 (file)
@@ -1,3 +1,4 @@
 def graph(self,
-    a: Tensor) -> Tuple[]:
+    a: Tensor) -> None:
   print("hi\016")
+  return None
index 7442d34..b97b33b 100644 (file)
@@ -7,4 +7,4 @@ def graph(self,
   while _0:
     y_2 = torch.add_(y, 1, 1)
     _0, y, z = bool(torch.lt(y_2, 8)), y_2, x
-  return x, z
+  return (x, z)
index b28fa6c..eb5b0e2 100644 (file)
@@ -2,5 +2,6 @@ graph(%a : Tensor
       %b : Tensor) {
   %2 : int = prim::Constant[value=1]()
   %3 : Tensor = aten::add(%a, %b, %2)
-  return (%3, %3);
+  %4 : (Tensor, Tensor) = prim::TupleConstruct(%3, %3)
+  return (%4);
 }
index 249e266..a46fc5f 100644 (file)
@@ -4,5 +4,6 @@ graph(%x : Double(2, 2)
   %4 : Double(2, 2) = aten::mul(%x, %3)
   %5 : Double(2, 2) = aten::mul(%x, %2)
   %6 : (Double(2, 2), Double(2, 2)) = prim::TupleConstruct(%4, %5)
-  return (%x, %6);
+  %7 : (Double(2, 2), (Double(2, 2), Double(2, 2))) = prim::TupleConstruct(%x, %6)
+  return (%7);
 }
index c40d3e5..94aeb9f 100644 (file)
@@ -5,5 +5,6 @@ graph(%a.1 : Tensor
   %a.3 : Tensor = aten::sub_(%a.2, %b, %2)
   %a.4 : Tensor = aten::div_(%a.3, %b)
   %a : Tensor = aten::mul_(%a.4, %b)
-  return (%a, %b);
+  %7 : (Tensor, Tensor) = prim::TupleConstruct(%a, %b)
+  return (%7);
 }
index 7f7340e..29e11ac 100644 (file)
@@ -29,5 +29,6 @@ graph(%cond : Tensor) {
        = prim::Print(%d, %e, %d, %5, %y.4, %5)
       -> (%c.1, %y.4)
     }
-  return (%a, %3, %c, %5, %y);
+  %19 : (int, int, int, Tensor, Tensor) = prim::TupleConstruct(%a, %3, %c, %5, %y)
+  return (%19);
 }
index 836ac75..0d516a9 100644 (file)
@@ -9,5 +9,6 @@ graph(%x.1 : Float(*, *)
     block1() {
       -> (%x.1, %x.1, %y.1)
     }
-  return (%x, %y, %z);
+  %7 : (Float(*, *), Tensor, Tensor) = prim::TupleConstruct(%x, %y, %z)
+  return (%7);
 }
index 054ec40..0ac2218 100644 (file)
@@ -8,5 +8,6 @@ graph(%a.1 : Tensor) {
   %7 : int[] = prim::ListConstruct(%5, %6)
   %8 : Tensor = aten::rand(%7, %4, %3, %2)
   %a : Tensor = aten::add_(%a.1, %8, %1)
-  return ();
+  %10 : None = prim::None()
+  return (%10);
 }
index 0e2a47a..649b714 100644 (file)
@@ -69,5 +69,4 @@
     (assert
       (eq (const 1) (const 1))
       (option (string_literal hello)))
-    (return
-      (list (variable (ident x))))))
+    (return (variable (ident x)))))
index 47d845a..5854e9b 100644 (file)
@@ -14,5 +14,6 @@ graph(%a : Tensor) {
     }
   %8 : int = prim::TupleIndex[index=0](%b)
   %9 : int = prim::TupleIndex[index=1](%b)
-  return (%8, %9);
+  %10 : (int, int) = prim::TupleConstruct(%8, %9)
+  return (%10);
 }
index 1628b7b..e1a8749 100644 (file)
@@ -15,6 +15,5 @@ graph(%a : Tensor) {
     }
   %c : (int, int, int, int) = prim::TupleSlice[beg=0, end=4](%b)
   %e : (int, int) = prim::TupleSlice[beg=1, end=3](%c)
-  %11 : int, %12 : int = prim::TupleUnpack(%e)
-  return (%11, %12);
+  return (%e);
 }
index 8c1acf0..89ba4cc 100644 (file)
@@ -182,8 +182,12 @@ def get_execution_plan(graph_executor_state):
 
 
 def get_grad_executor(plan_state, diff_graph_idx=None):
-    if diff_graph_idx is None and len(list(plan_state.graph.nodes())) != 1:
-        raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")
+    if diff_graph_idx is None:
+        nodes = list(plan_state.graph.nodes())
+        if len(nodes) == 1 or (len(nodes) == 2 and nodes[1].kind() == "prim::TupleConstruct"):
+            pass
+        else:
+            raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")
     grad_executors = list(plan_state.code.grad_executors())
     return grad_executors[diff_graph_idx or 0]
 
@@ -262,7 +266,6 @@ class JitTestCase(TestCase):
             ppv = "op_version_set = 0\n{}".format(pp)
             sm = copy_structure_and_params(module)
             torch._C._jit_import_methods(sm, ppv, constant_table)
-
             pp2, _ = sm._python_print()
             if pp != pp2:
                 self.assertMultiLineEqual(pp, pp2)
@@ -553,6 +556,14 @@ class TestJit(JitTestCase):
         with self.assertRaisesRegex(pickle.PickleError, "not supported"):
             torch.save(FooToPickle(), "will_fail")
 
+    def test_single_tuple_trace(self):
+        x = torch.tensor(2.)
+
+        def f2(x):
+            return (x,)
+        jit_f2 = torch.jit.trace(f2, x)
+        assert f2(x) == jit_f2(x)  # fails
+
     @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
     def test_restore_device_cuda(self):
         class MyModule(torch.jit.ScriptModule):
@@ -4272,7 +4283,7 @@ a")
         self.checkScript(one_return, [a], optimize=True)
         self.checkScript(multiple_returns, [a], optimize=True)
 
-        with self.assertRaisesRegex(RuntimeError, "Expected 1 return value"):
+        with self.assertRaisesRegex(RuntimeError, "but is actually of type None"):
             @torch.jit.script
             def no_return_bad_annotation(a):
                 # type: (Tensor) -> Tensor
@@ -7549,10 +7560,7 @@ a")
         self.checkScript(tuple_index, (torch.tensor([1]),), optimize=True)
         tuple_comp = torch.jit.script(tuple_index)
         self.assertExpectedGraph(tuple_comp.graph)
-        self.run_pass('lower_all_tuples', tuple_comp.graph)
-        m = torch.jit.ScriptModule()
-        m._create_method_from_graph("forward", tuple_comp.graph)
-        self.assertEqual(m(torch.tensor(1)), (1, 2))
+        self.assertEqual(tuple_comp(torch.tensor(1)), (1, 2))
 
         with self.assertRaisesRegex(RuntimeError, "tuple indices must be integer constants"):
             @torch.jit.script
@@ -7596,21 +7604,19 @@ a")
             return e
 
         self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True)
+        tuple_graph = torch.jit.script(tuple_slice)
+        self.assertExpectedGraph(tuple_graph.graph)
+        self.run_pass('lower_all_tuples', tuple_graph.graph)
+        self.assertTrue('Tuple' not in str(tuple_graph.graph))
         tuple_comp = torch.jit.script(tuple_slice)
-        self.assertExpectedGraph(tuple_comp.graph)
-        self.run_pass('lower_all_tuples', tuple_comp.graph)
-        self.assertTrue('Tuple' not in str(tuple_comp.graph))
-        m = torch.jit.ScriptModule()
-        m._create_method_from_graph("forward", tuple_comp.graph)
-        self.assertEqual(m(torch.tensor(1)), (2, 3))
+        self.assertEqual(tuple_comp(torch.tensor(1)), (2, 3))
 
         @torch.jit.script
         def test_indexing_end_out_of_bounds():
             c = (1, 2)
             return c[2:10]
 
-        # output is None in script and () in python
-        self.assertEqual(test_indexing_end_out_of_bounds(), None)
+        self.assertEqual(test_indexing_end_out_of_bounds(), ())
 
     def test_unwrap_optional_builtin(self):
         def test(x):
@@ -7665,9 +7671,7 @@ a")
         self.assertExpected(sm.__getattr__('forward').pretty_print_schema())
 
     def test_annotated_script_fn_return_mismatch(self):
-        with self.assertRaisesRegex(RuntimeError, r"Return value at position 0 was annotated as "
-                                                  r"having type \(Tensor, Tensor\) but is "
-                                                  r"actually of type Tensor"):
+        with self.assertRaisesRegex(RuntimeError, "but is actually of type"):
             @torch.jit.script
             def return_tup(x):
                 # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
@@ -9353,6 +9357,14 @@ class TestPytorchExportModes(JitTestCase):
                            export_type=torch.onnx.ExportTypes.DIRECTORY)
         shutil.rmtree(d)
 
+    def test_onnx_multiple_return(self):
+        @torch.jit.script
+        def foo(a):
+            return (a, a)
+        f = io.BytesIO()
+        x = torch.ones(3)
+        torch.onnx._export(foo, (x,), f, example_outputs=(x, x))
+
     @skipIfRocm
     @skipIfNoLapack
     def test_aten_fallback(self):
@@ -9571,19 +9583,10 @@ def check_alias_annotation(method_name, args, kwargs):
 
 def check_output_types(self, func, ref_outputs, args, kwargs):
     graph = getattr(func, 'last_graph', None)
-    if not isinstance(ref_outputs, tuple):
-        ref_outputs = (ref_outputs,)
     types = [o.type() for o in graph.outputs()]
-    self.assertEqual(len(types), len(ref_outputs))
-    for i, (t, ref_out) in enumerate(zip(types, ref_outputs)):
-        if isinstance(ref_out, list):
-            assert len(ref_out) > 0
-            elem = ref_out[0]
-            assert isinstance(elem, torch.Tensor)
-            self.assertTrue(t.isSubtypeOf(torch._C.ListType.ofTensors()))
-        else:
-            ref_type = torch._C.Type.inferFrom(ref_out)
-            self.assertTrue(ref_type.isSubtypeOf(t))
+    self.assertTrue(len(types) == 1)
+    t = types[0]
+    torch._C._jit_assert_is_instance(ref_outputs, t)
 
 
 def check_against_reference(self, func, reference_func, args, kwargs=None,
@@ -10289,7 +10292,7 @@ class TestFuser(JitTestCase):
         expected1, expected2 = f(x, y)
         self.assertEqual(result1, expected1)
         self.assertEqual(result2, expected2)
-        self.assertAllFused(script_f.graph_for(x, y))
+        self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
 
     # TODO: This test seems dead
     @unittest.skipIf(not IS_WINDOWS, "Testing Fuse skipped on windows")
@@ -10372,7 +10375,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
 
         graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
 
-        self.assertGraphSize(graph, 2)
+        self.assertGraphSize(graph, 3)
         self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
 
     def test_merges_without_cycles(self):
@@ -10404,7 +10407,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
 
         graph = self._perform_ad_subgraph_slicing(fn, 2, 2)
 
-        self.assertGraphSize(graph, 1)
+        self.assertGraphSize(graph, 2)
         self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
 
     def test_does_not_create_cycles(self):
@@ -10434,7 +10437,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
 
         graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
 
-        self.assertGraphSize(graph, 2)
+        self.assertGraphSize(graph, 3)
         self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
 
     def test_merges_down(self):
@@ -10449,7 +10452,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
 
         graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
 
-        self.assertGraphSize(graph, 2)
+        self.assertGraphSize(graph, 3)
         self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
 
     def test_respects_lexical_scoping(self):
index 77c9d54..f0e35f6 100644 (file)
@@ -356,6 +356,9 @@ void initJITBindings(PyObject *module) {
     // TODO: this is a fake stub
   });
 
+  m.def("_jit_assert_is_instance", [](py::object obj, TypePtr type) {
+    toIValue(obj, type);
+  });
 
   initPythonIRBindings(module);
   tracer::initPythonTracerBindings(module);
index 3930b2d..27e7594 100644 (file)
@@ -562,6 +562,10 @@ struct PythonPrintPass {
       return;
     switch (node->kind()) {
       case prim::Return:
+        if (enforce_importable_ && node->inputs().size() != 1) {
+          throw script::ErrorReport(node->getSourceLocation())
+              << "Exportable methods must have a single return value. Normal use of ScriptMethods should enforce this.";
+        }
         if (node->inputs().size() > 0) {
           indent();
           out << "return ";
@@ -813,7 +817,17 @@ struct PythonPrintPass {
     return out;
   }
 
-  void printDefaultValue(std::ostream& stmt, const IValue& value) {
+  void printDefaultValue(const TypePtr& typ, std::ostream& stmt, const IValue& value) {
+    // xxx - many weak script modules store default values for broadcasting lists
+    // that are not actually the same type as the argument. We can only serialize
+    // default values that will implicitly convert to their declared return type
+    // since we do not need to serialize these built-in modules with their defaults,
+    // we just drop them for now.
+    if (typ->kind() == ListType::Kind &&
+        (value.isInt() || value.isDouble() || value.isBool())) {
+      return;
+    }
+    stmt << "=";
     if (value.isTensor() && !value.toTensor().defined()) {
       // XXX - because undefined tensors are not stored as None, we need special handling.
       // otherwise they get printed as CONSTANTS.c0 and then cannot be recreated because
@@ -856,8 +870,7 @@ struct PythonPrintPass {
       if (defaults_offset != defaults.end()) {
         const c10::optional<IValue>& def = *defaults_offset++;
         if (def) {
-          out << "=";
-          printDefaultValue(out, *def);
+          printDefaultValue(input->type(), out, *def);
         }
       }
     }
index 8877806..37992b0 100644 (file)
@@ -1,5 +1,6 @@
 #include <torch/csrc/jit/passes/to_batch.h>
 #include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
 
 namespace torch { namespace jit {
 
@@ -19,6 +20,15 @@ std::shared_ptr<Graph> ToBatch::getBatchOperator(const std::string& name, int64_
   throw std::runtime_error("function " + name + " with " + std::to_string(num_inputs) + " inputs is not supported in batched tensor yet");
 }
 
+std::vector<Value*> inlineUnpackedCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
+  auto outputs = script::inlineCallTo(g, callee, inputs);
+  if (callee.outputs().size() == 1 && callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
+    auto tc = script::createTupleUnpack(outputs.at(0));
+    outputs = std::vector<Value*>(tc.begin(), tc.end());
+  }
+  return outputs;
+}
+
 // replace aten operator node with BatchTensor operator graph
 void ToBatch::visitAten(Node* n, Block* block, Block* res_block){
   auto res_graph = res_block->owningGraph();
@@ -44,13 +54,13 @@ void ToBatch::visitAten(Node* n, Block* block, Block* res_block){
   }
 
   auto batch_graph = getBatchOperator(func_name, new_inputs.size());
-  auto outputs = script::inlineCallTo(*res_block->owningGraph(), *batch_graph, new_inputs);
+  auto outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *batch_graph, new_inputs);
 
   // Assume all outputs from inlined operator implementation are in the triple form batched tensor or just a single non-tensor.
   if (outputs.size() == 1) {
     // if previous output is scalar, transform new output back to scalar from dynamic
     TypePtr orig_type = n->outputs()[0]->type();
-    if (orig_type != outputs[0]->type()){
+    if (!orig_type->isSubtypeOf(outputs[0]->type())) {
       Symbol op;
       if (orig_type == IntType::get()) {
         op = prim::Int;
@@ -88,7 +98,7 @@ void ToBatch::visitNumToTensor(Node* n, Block* block, Block* res_block){
   auto res_graph = res_block->owningGraph();
   auto* r_node = res_graph->createClone(n, rn_fn);
   res_block->appendNode(r_node);
-  auto outputs = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("batch_from_scalar_tensor"), r_node->outputs());
+  auto outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("batch_from_scalar_tensor"), r_node->outputs());
   batch_map[n->output()] = outputs;
 }
 
@@ -229,7 +239,7 @@ void ToBatch::visitIf(Node* n, Block* block, Block* res_block){
     inputs.insert(inputs.end(), if_output.begin(), if_output.end());
     auto else_output = batch_map.at(n->blocks()[1]->outputs()[i]);
     inputs.insert(inputs.end(), else_output.begin(), else_output.end());
-    auto outputs = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("where", inputs.size()), inputs);
+    auto outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("where", inputs.size()), inputs);
     batch_map[n->outputs()[i]] = outputs;
   }
 }
@@ -338,7 +348,7 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){
   }
   if(cond_is_tensor){
     auto cond = batch_map.at(n->inputs()[1]);
-    auto cond_any = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
+    auto cond_any = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
     rn_env[n->inputs()[1]] =res_graph->insert(prim::Bool, {cond_any[0]});
   }
   for(size_t i = 2; i < n->inputs().size(); i++){
@@ -400,7 +410,7 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){
       for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
         inputs.push_back(loop_block->inputs()[i * EXP_BTENSOR_SIZE + j + EXP_BTENSOR_SIZE + 1]);
       }
-      outputs = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("where"), inputs);
+      outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("where"), inputs);
     }
     else{
       for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
@@ -408,7 +418,7 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){
       }
       auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]);
       inputs.insert(inputs.end(), data.begin(), data.end());
-      outputs = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("update"), inputs);
+      outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("update"), inputs);
     }
     batch_map[n->outputs()[i]] = outputs;
     for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
@@ -419,7 +429,7 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){
   // update loop conditions
   if(cond_is_tensor){
     auto cond = batch_map.at(n->blocks()[0]->outputs()[0]);
-    auto cond_any = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
+    auto cond_any = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
     auto to_bool_output = res_graph->insert(prim::Bool, {cond_any[0]});
     loop_block->insertOutput(0,  to_bool_output);
     for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
@@ -513,17 +523,33 @@ void ToBatch::toBatch(Block* block, Block* res_block) {
   }
 }
 
-std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph>& graph){
+std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph) {
+  // lower the tuple before the pass
+  if (graph->outputs().at(0)->type()->kind() == TupleType::Kind) {
+    graph = graph->copy();
+    auto outs = script::createTupleUnpack(graph->outputs().at(0));
+    graph->eraseOutput(0);
+    for(auto o : outs)
+      graph->registerOutput(o);
+    EliminateDeadCode(graph->block());
+  }
   std::shared_ptr<Graph> res_graph = std::make_shared<Graph>();
   ToBatch to_batch;
   to_batch.toBatch(graph->block(), res_graph->block());
 
+  // methods should only have a single output, so we pack everything into a tuple
+  auto tup = res_graph->insertNode(res_graph->createTuple(res_graph->outputs()));
+  while (res_graph->outputs().size() > 0)
+    res_graph->eraseOutput(res_graph->outputs().size() - 1);
+  res_graph->registerOutput(tup->output());
+  EliminateDeadCode(res_graph->block());
+
   return res_graph;
 }
 
 void initRegisterBatchOpsBindings(PyObject* module) {
   auto m = py::handle(module).cast<py::module>();
-  m.def("to_batch_graph", &to_batch_graph);
+  m.def("to_batch_graph", to_batch_graph);
   m.def("register_batch_operator", [](std::string name, std::shared_ptr<Graph> graph){
     ToBatch::batch_operator_table[name].push_back(graph);
   });
index 4c2a098..959f265 100644 (file)
@@ -33,6 +33,6 @@ public:
   TORCH_API void toBatch(Block* block, Block* res_block);
 };
 
-TORCH_API std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph>& graph);
+TORCH_API std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph);
 TORCH_API void initRegisterBatchOpsBindings(PyObject* module);
 }}
index 2b6c336..5089ada 100644 (file)
@@ -408,7 +408,7 @@ inline py::object invokeScriptMethodFromPython(
     AutoNoGIL no_gil_guard;
     method.run(stack);
   }
-  return createPyObjectForStack(std::move(stack));
+  return toPyObject(std::move(stack.back()));
 }
 
 inline py::object invokeOperatorFromPython(
index c34335a..5a39fe3 100644 (file)
@@ -449,8 +449,7 @@ void initPythonIRBindings(PyObject * module_) {
     })
     .def("isSubtypeOf", [](std::shared_ptr<Type>& self, std::shared_ptr<Type> other) {
         return self->isSubtypeOf(other);
-    })
-    .def_static("inferFrom", c10::inferTypeFrom);
+    });
 
   py::class_<NumberType, Type, std::shared_ptr<NumberType>>(m, "NumberType")
     .def_static("get", &NumberType::get);
index a83b645..30210a1 100644 (file)
@@ -60,10 +60,7 @@ std::shared_ptr<torch::jit::Graph> createGraphByTracing(
       AT_ERROR("The traced function didn't return any values! Side-effects are not "
                "captured in traces, so it would be a no-op.");
     }
-    if (!PyTuple_Check(out.ptr())) {
-      out = py::make_tuple(out);
-    }
-    tracer::exit(toStack(out));
+    tracer::exit({toIValue(out)});
     auto graph = enter_info.first->graph;
     EliminateDeadCode(graph);
     LowerSimpleTuples(graph);
index 5b018a5..ad6c186 100644 (file)
@@ -448,13 +448,32 @@ Value* tryConvertToType(
     const TypePtr& concrete_type,
     Value* value,
     bool allow_conversions) {
-  // Allow homogeneous tuples to be casted implicitly to lists of appropriate
-  // types
-  if (convertibleToList(value->type(), unwrapOptional(concrete_type)) &&
-      value->type()->kind() == TypeKind::TupleType) {
-    auto unpacked = createTupleUnpack(value);
-    auto elem_type = unwrapOptional(concrete_type)->expect<ListType>()->getElementType();
-    value = graph.insertNode(graph.createList(elem_type, unpacked))->output();
+
+  if (auto value_tuple = value->type()->cast<TupleType>()) {
+    // Allow homogeneous tuples to be casted implicitly to lists of appropriate
+    // types
+    if (convertibleToList(value->type(), unwrapOptional(concrete_type))) {
+      auto unpacked = createTupleUnpack(value);
+      auto elem_type = unwrapOptional(concrete_type)->expect<ListType>()->getElementType();
+      value = graph.insertNode(graph.createList(elem_type, unpacked))->output();
+    }
+    // inductively apply implicit conversions to tuples
+    if (auto concrete_tuple = concrete_type->cast<TupleType>()) {
+      if (!value_tuple->isSubtypeOf(concrete_tuple) &&
+          concrete_tuple->elements().size() == value_tuple->elements().size()) {
+        auto unpacked = createTupleUnpack(value);
+        std::vector<Value*> converted;
+        for (size_t i = 0; i < concrete_tuple->elements().size(); ++i) {
+          converted.emplace_back(tryConvertToType(
+              loc,
+              graph,
+              concrete_tuple->elements().at(i),
+              unpacked.at(i),
+              allow_conversions));
+        }
+        value = graph.insertNode(graph.createTuple(converted))->output();
+      }
+    }
   }
 
   if (value->type()->isSubtypeOf(NoneType::get()) && !concrete_type->isSubtypeOf(NoneType::get())){
@@ -834,9 +853,9 @@ struct to_ir {
     if (self && def.decl().params().size() == 0) {
       throw ErrorReport(def.decl().params().range()) << "methods must have a self argument";
     }
-
     auto schema = extractSchemaFromDef(def);
     std::vector<Argument> arguments = emitFormalArguments(self, schema);
+
     // body
     auto stmts = def.statements();
     auto stmts_begin = stmts.begin();
@@ -847,7 +866,8 @@ struct to_ir {
       return_stmt = Return(*stmts_end);
     }
     emitStatements(stmts_begin, stmts_end);
-    std::vector<Argument> returns = emitReturn(return_stmt, schema);
+    std::vector<Argument> returns = {emitReturn(
+        return_stmt ? return_stmt->range() : def.range(), return_stmt, schema)};
 
     method.setSchema({def.name().name(), std::move(arguments), std::move(returns)});
     // remove any uses of tuples that we inserted that are not needed
@@ -895,7 +915,7 @@ private:
         Decl::create(r, List<Param>::create(r, {}), Maybe<Expr>::create(r, tuple_type));
 
     auto tuple_expr = TupleLiteral::create(r, List<Expr>::create(r, default_exprs));
-    auto ret = Return::create(r, List<Expr>::create(r, { tuple_expr }));
+    auto ret = Return::create(r, tuple_expr);
     auto def = Def::create(
         r,
         Ident::create(r, "defaults"),
@@ -903,8 +923,9 @@ private:
         List<Stmt>::create(r, {ret}));
     auto m = std::make_shared<Module>();
     defineMethodsInModule(m, {def}, {resolver}, nullptr);
-    m->get_method("defaults").run(default_values);
-    return default_values;
+    Stack stack;
+    m->get_method("defaults").run(stack);
+    return stack.at(0).toTuple()->elements();
   }
 
   std::vector<Argument> parseArgsFromDecl(const Decl& decl) {
@@ -957,44 +978,29 @@ private:
     return retval;
   }
 
-  std::vector<Argument> parseReturnsFromDecl(const Decl& decl) {
-    JIT_ASSERT(decl.return_type().present());
+  std::vector<Argument> parseReturnFromDecl(const Decl& decl) {
+    // we represent no annoation on a return type as having no values in the
+    // schema's return() list
+    // in emitReturn we take the actual return value to be the value of the return
+    // statement if no one was provided here
+    if(!decl.return_type().present())
+      return {};
+
     if (handleBroadcastList(decl.return_type().get()))
       throw ErrorReport(decl.return_type().range()) << "Broadcastable lists cannot appear as a return type";
     auto parsed_type = parseTypeFromExpr(decl.return_type().get());
-    if (auto tuple_type = parsed_type->cast<TupleType>()) {
-      // Flatten a single return type of type Tuple into its constituent types
-      std::vector<Argument> retval;
-      for (auto type_ptr : tuple_type->elements()) {
-        retval.emplace_back(
-            "",
-            type_ptr,
-            /*N =*/c10::nullopt,
-            /*default_value =*/c10::nullopt,
-            /*kwarg_only =*/false);
-      }
-      return retval;
-    } else {
-      return {Argument(
-          "",
-          parsed_type,
-          /*N =*/c10::nullopt,
-          /*default_value =*/c10::nullopt,
-          /*kwarg_only =*/false)};
-    }
+    return {Argument(
+        "",
+        parsed_type,
+        /*N =*/c10::nullopt,
+        /*default_value =*/c10::nullopt,
+        /*kwarg_only =*/false)};
   }
   FunctionSchema extractSchemaFromDef(const Def &def) {
       auto name = def.name().name();
       std::vector<Argument> args = parseArgsFromDecl(def.decl());
-      std::vector<Argument> returns;
-      bool is_varret;
-      if (def.decl().return_type().present()) {
-        returns = parseReturnsFromDecl(def.decl());
-        is_varret = false;
-      } else {
-        is_varret = true;
-      }
-      return FunctionSchema(name, args, returns, false, is_varret);
+      std::vector<Argument> returns = parseReturnFromDecl(def.decl());
+      return FunctionSchema(name, std::move(args), std::move(returns), false, false);
   }
 
   std::vector<Argument> emitFormalArguments(const SugaredValuePtr& self, const FunctionSchema& schema) {
@@ -1030,50 +1036,27 @@ private:
     }
     return arguments;
   }
-  std::vector<Argument> emitReturn(c10::optional<Return> return_stmt_, const FunctionSchema& schema) {
+
+  Argument emitReturn(const SourceRange& range, c10::optional<Return> return_stmt, const FunctionSchema& schema) {
+    JIT_ASSERT(schema.returns().size() <= 1);
     // outputs
-    std::vector<Argument> returns;
-    if (return_stmt_) {
-      auto return_stmt = *return_stmt_;
-      auto results = getValues(return_stmt.values(), true);
-      // a single return value that is a tuple expands in place:
-      // return a
-      if (return_stmt.values().size() == 1 && results.size() == 1) {
-        auto result = results.at(0);
-        if(result->type()->cast<TupleType>()) {
-          results = createTupleUnpack(result).vec();
-        }
-      }
-      if (!schema.is_varret() && schema.returns().size() != results.size()) {
-        throw ErrorReport(def.range()) << "Number of type annotations for function"
-          << " return (" << schema.returns().size() << ") does not match"
-          << " the number of returns from the function (" << results.size() << ")!";
-      }
-      const auto& range = return_stmt.range();
-      size_t return_type_idx = 0;
-      for (auto r : results) {
-        TypePtr type = DynamicType::get();
-        if (!schema.is_varret()) {
-          type = schema.returns().at(return_type_idx).type();
-          r = tryConvertToType(range, *graph, type, r, /*allow_conversions=*/false);
-          if (!r->type()->isSubtypeOf(type)) {
-            throw ErrorReport(return_stmt.range()) << "Return value at position "
-              << return_type_idx << " was annotated as having type " << type->str()
-              << " but is actually of type " << r->type()->str();
-          }
-          return_type_idx++;
-        }
-        graph->registerOutput(r);
-        returns.emplace_back("", type);
-      }
-    } else if (schema.returns().size() > 0) {
-      // schema has returns but there's no return nodes in graph
-      throw ErrorReport() << "Expected " << schema.returns().size()
-                          << " return value"
-                          << (schema.returns().size() > 1 ? "s" : "")
-                          << " but found no return statement";
+    Value* result = return_stmt ? emitExpr(return_stmt->expr())
+                                : graph->insertConstant(IValue(), range);
+    TypePtr result_type = schema.returns().size() > 0
+        ? schema.returns().at(0).type()
+        : result->type();
+
+    if (return_stmt) {
+      result = tryConvertToType(
+          range, *graph, result_type, result, /*allow_conversions=*/true);
+    }
+
+    if (!result->type()->isSubtypeOf(result_type)) {
+      throw ErrorReport(range) << "Return value was annotated as having type " << result_type->python_str()
+        << " but is actually of type " << result->type()->python_str();
     }
-    return returns;
+    graph->registerOutput(result);
+    return Argument("", result_type);
   }
   void emitStatements(const List<Stmt>& statements) {
     return emitStatements(statements.begin(), statements.end());
@@ -2720,6 +2703,7 @@ const std::unordered_map<std::string, TypePtr> &ident_to_type_lut() {
     // technically this is not a python type but we need it when
     // parsing serialized methods that use implicit converions to Scalar
     {"number", NumberType::get()},
+    {"None", NoneType::get()},
   };
   return map;
 }
@@ -2769,12 +2753,15 @@ c10::optional<std::string> parseBaseTypeName(const Expr& expr) {
     case TK_VAR: {
       return Var(expr).name().name();
     }
+    case TK_NONE: {
+      return "None";
+    }
     case '.': {
       auto select = Select(expr);
       const std::string& name = select.selector().name();
       if (isTorch(select.value()) && name == "Tensor")
         return "Tensor";
-    }
+    } break;
   }
   return at::nullopt;
 }
index 7abe842..1730a88 100644 (file)
@@ -277,6 +277,8 @@ TORCH_API c10::optional<size_t> findInputWithName(
   const std::string& name,
   at::ArrayRef<NamedValue> kwargs);
 
+TORCH_API at::ArrayRef<Value*> createTupleUnpack(Value* v);
+
 } // namespace script
 } // namespace jit
 } // namespace torch
index 236a727..e0d51bf 100644 (file)
@@ -525,7 +525,7 @@ void initJitScriptBindings(PyObject* module) {
              const std::string& script,
              ResolutionCallback rcb, bool has_self) {
             auto self = has_self ? std::make_shared<ModuleValue>(m) : nullptr;
-            return defineMethodsInModule(m, script, pythonResolver(rcb), self);
+            defineMethodsInModule(m, script, pythonResolver(rcb), self);
           })
       .def("_create_methods", [](std::shared_ptr<Module> m,
           const std::vector<Def>& defs,
@@ -646,7 +646,7 @@ void initJitScriptBindings(PyObject* module) {
         if (self.find_method("forward")) {
           Method & m = self.get_method("forward");
           return m.graph_for(
-              createStackForSchema(m.getSchema(), tuple_slice(std::move(args), 1), std::move(kwargs)));
+              createStackForSchema(m.getSchema(), tuple_slice(std::move(args), 1), kwargs));
         }
         throw std::runtime_error("Attempted to call graph_for on a Module without a compiled forward()");
       })
@@ -708,7 +708,7 @@ void initJitScriptBindings(PyObject* module) {
     .def("graph_for", [](py::args args, py::kwargs kwargs) {
       // see: [pybind11 varargs]
       Method& self = py::cast<Method&>(args[0]);
-      return self.graph_for(createStackForSchema(self.getSchema(), tuple_slice(std::move(args), 1), std::move(kwargs)));
+      return self.graph_for(createStackForSchema(self.getSchema(), tuple_slice(std::move(args), 1), kwargs));
     })
     .def("debug_disable_autodiff_subgraph_inlining", &Method::debugDisableAutodiffSubgraphInlining)
     .def("schema", &Method::getSchema)
index 0b57080..aeed2d8 100644 (file)
@@ -69,9 +69,6 @@ struct Method {
   IValue operator()(std::vector<IValue> stack) {
     checkInputsAgainstSchema(stack);
     run(stack);
-    if (stack.size() != 1) {
-      return Tuple::create(std::move(stack));
-    }
     return stack.front();
   }
 
@@ -144,12 +141,18 @@ struct Method {
       auto type = torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
       retval->inputs()[i]->setType(type);
     }
-    JIT_ASSERT(retval->outputs().size() == outputs.size());
+    at::ArrayRef<Value*> output_values = retval->outputs();
+    // patch this to still work if we are returning a tuple of multiple values
+    if (output_values.at(0)->type()->kind() == TupleType::Kind) {
+      JIT_ASSERT(output_values.at(0)->node()->kind()== prim::TupleConstruct);
+      output_values = output_values.at(0)->node()->inputs();
+    }
+    JIT_ASSERT(output_values.size() == outputs.size());
     for (size_t i=0; i < retval->outputs().size(); ++i) {
       auto scalar_type = outputs[i].type().scalarType();
       auto sizes = outputs[i].sizes();
       auto type = torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
-      retval->outputs()[i]->setType(type);
+      output_values[i]->setType(type);
     }
     return retval;
   }
@@ -213,7 +216,10 @@ private:
   }
 
   GraphExecutor& get_executor() {
-    std::call_once(executor_init, [&]{
+    std::call_once(executor_init, [&] {
+      AT_CHECK(
+          graph()->outputs().size() == 1,
+          "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
       executor = GraphExecutor(graph(), optimize);
     });
     return executor;
@@ -231,7 +237,10 @@ private:
     for (size_t pos = 0; pos < schema.arguments().size(); ++pos) {
       const auto& argument = schema.arguments()[pos];
       if (pos < inputs.size()) {
-        const TypePtr inputType = inferTypeFrom(inputs[pos]);
+        // XXX - this fails to handle generic aggregates
+        // and should be replaced with a function isSubvalueOf(ivalue, type)
+        // That asks if the specific value is a valid instance of type.
+        const TypePtr inputType = incompleteInferTypeFrom(inputs[pos]);
         AT_CHECK(inputType->isSubtypeOf(argument.type()),
               "Expected value of type ", *argument.type(),
               " for argument '", argument.name(),
index e92acec..65201af 100644 (file)
@@ -365,9 +365,10 @@ struct ParserImpl {
       }
       case TK_RETURN: {
         auto range = L.next().range;
-        // XXX: TK_NEWLINE makes it accept an empty list
-        auto values = parseList(TK_NOTHING, ',', TK_NEWLINE, &ParserImpl::parseExp);
-        return Return::create(range, values);
+        Expr value = L.cur().kind != TK_NEWLINE ? parseExpOrExpTuple()
+                                                : Expr(c(TK_NONE, range, {}));
+        L.expect(TK_NEWLINE);
+        return Return::create(range, value);
       }
       case TK_RAISE: {
         auto range = L.next().range;
@@ -461,20 +462,24 @@ struct ParserImpl {
     } while(!L.nextIf(TK_DEDENT));
     return c(TK_LIST, r, std::move(stmts));
   }
-  Decl parseDecl() {
-    auto paramlist = parseList('(', ',', ')', &ParserImpl::parseParam);
-    // Parse return type annotation
-    TreeRef return_type;
+
+  Maybe<Expr> parseReturnAnnotation() {
     if (L.nextIf(TK_ARROW)) {
       // Exactly one expression for return type annotation
       auto return_type_range = L.cur().range;
-      return_type = Maybe<Expr>::create(return_type_range, parseExp());
+      return Maybe<Expr>::create(return_type_range, parseExp());
     } else {
-      // Default to returning single tensor. TODO: better sentinel value?
-      return_type = Maybe<Expr>::create(L.cur().range);
+      return Maybe<Expr>::create(L.cur().range);
     }
+  }
+
+  Decl parseDecl() {
+    auto paramlist = parseList('(', ',', ')', &ParserImpl::parseParam);
+    // Parse return type annotation
+    TreeRef return_type;
+    Maybe<Expr> return_annotation = parseReturnAnnotation();
     L.expect(':');
-    return Decl::create(paramlist.range(), List<Param>(paramlist), Maybe<Expr>(return_type));
+    return Decl::create(paramlist.range(), List<Param>(paramlist), return_annotation);
   }
 
   TreeRef parseFunction(bool is_method) {
index 8d2c39d..d42fc0e 100644 (file)
@@ -132,8 +132,8 @@ void initTreeViewBindings(PyObject *module) {
       return AugAssign::create(r, lhs, kind, rhs);
     }));
   py::class_<Return, Stmt>(m, "Return")
-    .def(py::init([](const SourceRange& range, std::vector<Expr> values) {
-      return Return::create(range, wrap_list(range, std::move(values)));
+    .def(py::init([](const SourceRange& range, Expr* value) {
+      return Return::create(range, value ? *value : Expr(Compound::create(TK_NONE, range, {})));
     }));
   py::class_<Raise, Stmt>(m, "Raise")
     .def(py::init([](const SourceRange& range, Expr *expr) {
index 270e1aa..6a8f7b0 100644 (file)
@@ -497,11 +497,11 @@ struct Return : public Stmt {
   explicit Return(const TreeRef& tree) : Stmt(tree) {
     tree_->match(TK_RETURN);
   }
-  List<Expr> values() const {
-    return List<Expr>(subtree(0));
+  Expr expr() const {
+    return Expr(subtree(0));
   }
-  static Return create(const SourceRange& range, const List<Expr>& values) {
-    return Return(Compound::create(TK_RETURN, range, {values}));
+  static Return create(const SourceRange& range, const Expr& value) {
+    return Return(Compound::create(TK_RETURN, range, {value}));
   }
 };
 
index 691a1d9..15de8c7 100644 (file)
@@ -162,7 +162,7 @@ inline std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) {
     }
   };
   for (IValue& input : inputs) {
-    input = add_input(input, inferTypeFrom(input), state->graph->addInput());
+    input = add_input(input, incompleteInferTypeFrom(input), state->graph->addInput());
   }
   return std::make_pair(state, inputs);
 }
index cd885c5..2c69d82 100644 (file)
@@ -253,8 +253,7 @@ class StmtBuilder(Builder):
     @staticmethod
     def build_Return(ctx, stmt):
         r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("return"))
-        values = (stmt.value,) if not isinstance(stmt.value, ast.Tuple) else stmt.value.elts
-        return Return(r, [build_expr(ctx, val) for val in values if val is not None])
+        return Return(r, None if stmt.value is None else build_expr(ctx, stmt.value))
 
     @staticmethod
     def build_Raise(ctx, stmt):