template<> inline TypePtr getTypePtr<std::vector<double>>() { return ListType::ofFloats(); }
template<> inline TypePtr getTypePtr<std::vector<int64_t>>() { return ListType::ofInts(); }
-CAFFE2_API TypePtr inferTypeFrom(const IValue& value);
+CAFFE2_API TypePtr incompleteInferTypeFrom(const IValue& value);
using TypeEnv = std::unordered_map<std::string, TypePtr>;
struct MatchTypeReturn {
return value;
}
-TypePtr inferTypeFrom(const IValue& value) {
+// why incomplete? You cannot completely recover a type from
+// an IValue; for example, List[List[int]] and List[List[Tensor]] will both
+// report ivalue.isGenericList(), so their element types cannot be recovered.
+// The only appropriate place to use this is where you know you are only
+// dealing with a subset of objects whose types can be recovered, like in
+// the tracer.
+TypePtr incompleteInferTypeFrom(const IValue& value) {
if (value.isTensor()) {
return CompleteTensorType::create(value.toTensor());
} else if (value.isDouble()) {
} else if (value.isDoubleList()) {
return ListType::ofFloats();
} else if (value.isTuple()) {
- return TupleType::create(fmap(value.toTuple()->elements(), inferTypeFrom));
+ return TupleType::create(fmap(value.toTuple()->elements(), incompleteInferTypeFrom));
} else if (value.isDevice()) {
return DeviceObjType::get();
}
- AT_ASSERTM(false, "Unhandled IValue kind in inferTypeFrom");
+ AT_ERROR("Type cannot be accurately recovered from this IValue.");
}
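A hedged illustration of the tracer case this comment points to (assumed example, not part of the diff): tracer inputs are limited to tensors, numbers, and flat aggregates of them, so incompleteInferTypeFrom can still recover their types from the example inputs.

    import torch

    def f(x):
        return x * 2

    # The tracer infers the input type from the example tensor's IValue.
    traced = torch.jit.trace(f, torch.ones(2))
    print(traced.graph)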
c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
%data : Tensor = aten::where(%mask, %data.1, %5_data)
-> (%7, %data, %mask, %dims)
}
- return (%x, %10, %11);
+ %22 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%x, %10, %11)
+ return (%22);
}
%6 : int = prim::Constant[value=1]()
%7 : Tensor = aten::gt(%a.1_data, %b_data)
%8 : Tensor = aten::mul(%a.1_mask, %b_mask)
- %9 : Tensor = aten::__or__(%a.1_dims, %b_dims)
- %10 : bool = prim::Bool(%7)
- %11 : Long() = prim::NumToTensor(%6)
- %alpha.1 : float = prim::Float(%11)
+ %9 : Long() = prim::NumToTensor(%6)
+ %alpha.1 : float = prim::Float(%9)
%data.1 : Tensor = aten::add(%a.1_data, %b_data, %alpha.1)
%mask.1 : Tensor = aten::mul(%a.1_mask, %b_mask)
%dims.1 : Tensor = aten::__or__(%a.1_dims, %b_dims)
- %16 : Long() = prim::NumToTensor(%6)
- %alpha : float = prim::Float(%16)
+ %14 : Long() = prim::NumToTensor(%6)
+ %alpha : float = prim::Float(%14)
%data : Tensor = aten::sub(%a.1_data, %b_data, %alpha)
%mask : Tensor = aten::mul(%a.1_mask, %b_mask)
%dims : Tensor = aten::__or__(%a.1_dims, %b_dims)
- %21 : bool = prim::Constant[value=1]()
- %22 : int = prim::Constant[value=1]()
- %23 : Tensor = aten::type_as(%8, %7)
- %data.2 : Tensor = aten::mul(%7, %23)
- %25 : int = aten::dim(%data.2)
- %26 : bool = aten::eq(%25, %22)
- %cond_data : Tensor, %cond_mask : Tensor = prim::If(%26)
+ %19 : bool = prim::Constant[value=1]()
+ %20 : int = prim::Constant[value=1]()
+ %21 : Tensor = aten::type_as(%8, %7)
+ %data.2 : Tensor = aten::mul(%7, %21)
+ %23 : int = aten::dim(%data.2)
+ %24 : bool = aten::eq(%23, %20)
+ %cond_data : Tensor, %cond_mask : Tensor = prim::If(%24)
block0() {
- %29 : int = aten::dim(%data.1)
- %30 : int = aten::sub(%29, %22)
- %data.4 : Tensor = prim::Loop(%30, %21, %data.2)
- block0(%32 : int, %33 : Tensor) {
- %34 : int = aten::dim(%33)
- %data.3 : Tensor = aten::unsqueeze(%33, %34)
- -> (%21, %data.3)
+ %27 : int = aten::dim(%data.1)
+ %28 : int = aten::sub(%27, %20)
+ %data.4 : Tensor = prim::Loop(%28, %19, %data.2)
+ block0(%30 : int, %31 : Tensor) {
+ %32 : int = aten::dim(%31)
+ %data.3 : Tensor = aten::unsqueeze(%31, %32)
+ -> (%19, %data.3)
}
%cond_data.1 : Tensor = aten::expand_as(%data.4, %data.1)
%cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask.1)
%res_data : Tensor = aten::where(%cond_data, %data.1, %data)
%res_mask : Tensor = aten::where(%cond_mask, %mask.1, %mask)
%res_dims : Tensor = aten::__or__(%dims.1, %dims)
- return (%res_data, %res_mask, %res_dims);
+ %39 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims)
+ return (%39);
}
%8 : Float() = prim::NumToTensor(%7)
%other : float = prim::Float(%8)
%10 : Tensor = aten::gt(%a.1_data, %other)
- %11 : bool = prim::Bool(%10)
- %12 : Long() = prim::NumToTensor(%6)
- %alpha.1 : float = prim::Float(%12)
+ %11 : Long() = prim::NumToTensor(%6)
+ %alpha.1 : float = prim::Float(%11)
%data.1 : Tensor = aten::add(%a.1_data, %b_data, %alpha.1)
%mask.1 : Tensor = aten::mul(%a.1_mask, %b_mask)
%dims.1 : Tensor = aten::__or__(%a.1_dims, %b_dims)
- %17 : Long() = prim::NumToTensor(%6)
- %alpha : float = prim::Float(%17)
+ %16 : Long() = prim::NumToTensor(%6)
+ %alpha : float = prim::Float(%16)
%data : Tensor = aten::sub(%a.1_data, %b_data, %alpha)
%mask : Tensor = aten::mul(%a.1_mask, %b_mask)
%dims : Tensor = aten::__or__(%a.1_dims, %b_dims)
- %22 : bool = prim::Constant[value=1]()
- %23 : int = prim::Constant[value=1]()
- %24 : Tensor = aten::type_as(%a.1_mask, %10)
- %data.2 : Tensor = aten::mul(%10, %24)
- %26 : int = aten::dim(%data.2)
- %27 : bool = aten::eq(%26, %23)
- %cond_data : Tensor, %cond_mask : Tensor = prim::If(%27)
+ %21 : bool = prim::Constant[value=1]()
+ %22 : int = prim::Constant[value=1]()
+ %23 : Tensor = aten::type_as(%a.1_mask, %10)
+ %data.2 : Tensor = aten::mul(%10, %23)
+ %25 : int = aten::dim(%data.2)
+ %26 : bool = aten::eq(%25, %22)
+ %cond_data : Tensor, %cond_mask : Tensor = prim::If(%26)
block0() {
- %30 : int = aten::dim(%data.1)
- %31 : int = aten::sub(%30, %23)
- %data.4 : Tensor = prim::Loop(%31, %22, %data.2)
- block0(%33 : int, %34 : Tensor) {
- %35 : int = aten::dim(%34)
- %data.3 : Tensor = aten::unsqueeze(%34, %35)
- -> (%22, %data.3)
+ %29 : int = aten::dim(%data.1)
+ %30 : int = aten::sub(%29, %22)
+ %data.4 : Tensor = prim::Loop(%30, %21, %data.2)
+ block0(%32 : int, %33 : Tensor) {
+ %34 : int = aten::dim(%33)
+ %data.3 : Tensor = aten::unsqueeze(%33, %34)
+ -> (%21, %data.3)
}
%cond_data.1 : Tensor = aten::expand_as(%data.4, %data.1)
%cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask.1)
%res_data : Tensor = aten::where(%cond_data, %data.1, %data)
%res_mask : Tensor = aten::where(%cond_mask, %mask.1, %mask)
%res_dims : Tensor = aten::__or__(%dims.1, %dims)
- return (%res_data, %res_mask, %res_dims);
+ %41 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims)
+ return (%41);
}
%6 : int = prim::Constant[value=1]()
%7 : Tensor = aten::gt(%a.1_data, %b_data)
%8 : Tensor = aten::mul(%a.1_mask, %b_mask)
- %9 : Tensor = aten::__or__(%a.1_dims, %b_dims)
- %10 : bool = prim::Bool(%7)
- %11 : Long() = prim::NumToTensor(%6)
- %alpha : float = prim::Float(%11)
+ %9 : Long() = prim::NumToTensor(%6)
+ %alpha : float = prim::Float(%9)
%data : Tensor = aten::add(%a.1_data, %b_data, %alpha)
%mask : Tensor = aten::mul(%a.1_mask, %b_mask)
%dims : Tensor = aten::__or__(%a.1_dims, %b_dims)
- %16 : bool = prim::Constant[value=1]()
- %17 : int = prim::Constant[value=1]()
- %18 : Tensor = aten::type_as(%8, %7)
- %data.2 : Tensor = aten::mul(%7, %18)
- %20 : int = aten::dim(%data.2)
- %21 : bool = aten::eq(%20, %17)
- %cond_data : Tensor, %cond_mask : Tensor = prim::If(%21)
+ %14 : bool = prim::Constant[value=1]()
+ %15 : int = prim::Constant[value=1]()
+ %16 : Tensor = aten::type_as(%8, %7)
+ %data.2 : Tensor = aten::mul(%7, %16)
+ %18 : int = aten::dim(%data.2)
+ %19 : bool = aten::eq(%18, %15)
+ %cond_data : Tensor, %cond_mask : Tensor = prim::If(%19)
block0() {
- %24 : int = aten::dim(%data)
- %25 : int = aten::sub(%24, %17)
- %data.4 : Tensor = prim::Loop(%25, %16, %data.2)
- block0(%27 : int, %28 : Tensor) {
- %29 : int = aten::dim(%28)
- %data.3 : Tensor = aten::unsqueeze(%28, %29)
- -> (%16, %data.3)
+ %22 : int = aten::dim(%data)
+ %23 : int = aten::sub(%22, %15)
+ %data.4 : Tensor = prim::Loop(%23, %14, %data.2)
+ block0(%25 : int, %26 : Tensor) {
+ %27 : int = aten::dim(%26)
+ %data.3 : Tensor = aten::unsqueeze(%26, %27)
+ -> (%14, %data.3)
}
%cond_data.1 : Tensor = aten::expand_as(%data.4, %data)
%cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask)
%res_data : Tensor = aten::where(%cond_data, %data, %a.1_data)
%res_mask : Tensor = aten::where(%cond_mask, %mask, %a.1_mask)
%res_dims : Tensor = aten::__or__(%dims, %a.1_dims)
- return (%res_data, %res_mask, %res_dims);
+ %34 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims)
+ return (%34);
}
%8 : Float() = prim::NumToTensor(%7)
%other : float = prim::Float(%8)
%10 : Tensor = aten::gt(%a.1_data, %other)
- %11 : bool = prim::Bool(%10)
- %12 : Long() = prim::NumToTensor(%6)
- %alpha : float = prim::Float(%12)
+ %11 : Long() = prim::NumToTensor(%6)
+ %alpha : float = prim::Float(%11)
%data : Tensor = aten::add(%a.1_data, %b_data, %alpha)
%mask : Tensor = aten::mul(%a.1_mask, %b_mask)
%dims : Tensor = aten::__or__(%a.1_dims, %b_dims)
- %17 : bool = prim::Constant[value=1]()
- %18 : int = prim::Constant[value=1]()
- %19 : Tensor = aten::type_as(%a.1_mask, %10)
- %data.2 : Tensor = aten::mul(%10, %19)
- %21 : int = aten::dim(%data.2)
- %22 : bool = aten::eq(%21, %18)
- %cond_data : Tensor, %cond_mask : Tensor = prim::If(%22)
+ %16 : bool = prim::Constant[value=1]()
+ %17 : int = prim::Constant[value=1]()
+ %18 : Tensor = aten::type_as(%a.1_mask, %10)
+ %data.2 : Tensor = aten::mul(%10, %18)
+ %20 : int = aten::dim(%data.2)
+ %21 : bool = aten::eq(%20, %17)
+ %cond_data : Tensor, %cond_mask : Tensor = prim::If(%21)
block0() {
- %25 : int = aten::dim(%data)
- %26 : int = aten::sub(%25, %18)
- %data.4 : Tensor = prim::Loop(%26, %17, %data.2)
- block0(%28 : int, %29 : Tensor) {
- %30 : int = aten::dim(%29)
- %data.3 : Tensor = aten::unsqueeze(%29, %30)
- -> (%17, %data.3)
+ %24 : int = aten::dim(%data)
+ %25 : int = aten::sub(%24, %17)
+ %data.4 : Tensor = prim::Loop(%25, %16, %data.2)
+ block0(%27 : int, %28 : Tensor) {
+ %29 : int = aten::dim(%28)
+ %data.3 : Tensor = aten::unsqueeze(%28, %29)
+ -> (%16, %data.3)
}
%cond_data.1 : Tensor = aten::expand_as(%data.4, %data)
%cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask)
%res_data : Tensor = aten::where(%cond_data, %data, %a.1_data)
%res_mask : Tensor = aten::where(%cond_mask, %mask, %a.1_mask)
%res_dims : Tensor = aten::__or__(%dims, %a.1_dims)
- return (%res_data, %res_mask, %res_dims);
+ %36 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims)
+ return (%36);
}
%8 : Tensor = aten::gt(%a.1_data, %b_data)
%9 : Tensor = aten::mul(%a.1_mask, %b_mask)
%10 : Tensor = aten::__or__(%a.1_dims, %b_dims)
- %11 : bool = prim::Bool(%8)
- %12 : int = prim::Constant[value=0]()
- %13 : Tensor = aten::mul(%8, %9)
- %14 : Tensor = aten::sum(%13)
- %15 : Tensor = aten::gt(%14, %12)
- %16 : bool = prim::Bool(%15)
- %17 : Tensor, %18 : Tensor, %19 : Tensor, %a : Tensor, %21 : Tensor, %22 : Tensor = prim::Loop(%7, %16, %8, %9, %10, %a.1_data, %a.1_mask, %a.1_dims)
- block0(%loop_num : int, %cond_data.2 : Tensor, %cond_mask.2 : Tensor, %cond_dims : Tensor, %6_data : Tensor, %6_mask : Tensor, %6_dims : Tensor) {
- %30 : Long() = prim::NumToTensor(%6)
- %alpha : float = prim::Float(%30)
+ %11 : int = prim::Constant[value=0]()
+ %12 : Tensor = aten::mul(%8, %9)
+ %13 : Tensor = aten::sum(%12)
+ %14 : Tensor = aten::gt(%13, %11)
+ %15 : bool = prim::Bool(%14)
+ %16 : Tensor, %17 : Tensor, %a : Tensor, %19 : Tensor, %20 : Tensor = prim::Loop(%7, %15, %8, %9, %a.1_data, %a.1_mask, %a.1_dims)
+ block0(%loop_num : int, %cond_data.2 : Tensor, %cond_mask.2 : Tensor, %6_data : Tensor, %6_mask : Tensor, %6_dims : Tensor) {
+ %27 : Long() = prim::NumToTensor(%6)
+ %alpha : float = prim::Float(%27)
%data : Tensor = aten::sub(%6_data, %b_data, %alpha)
%mask : Tensor = aten::mul(%6_mask, %b_mask)
%dims : Tensor = aten::__or__(%6_dims, %b_dims)
- %35 : Tensor = aten::gt(%data, %b_data)
- %36 : Tensor = aten::mul(%mask, %b_mask)
- %37 : Tensor = aten::__or__(%dims, %b_dims)
- %38 : bool = prim::Bool(%35)
- %39 : bool = prim::Constant[value=1]()
- %40 : int = prim::Constant[value=1]()
- %41 : Tensor = aten::type_as(%cond_mask.2, %cond_data.2)
- %data.2 : Tensor = aten::mul(%cond_data.2, %41)
- %43 : int = aten::dim(%data.2)
- %44 : bool = aten::eq(%43, %40)
- %cond_data : Tensor, %cond_mask : Tensor = prim::If(%44)
+ %32 : Tensor = aten::gt(%data, %b_data)
+ %33 : Tensor = aten::mul(%mask, %b_mask)
+ %34 : bool = prim::Constant[value=1]()
+ %35 : int = prim::Constant[value=1]()
+ %36 : Tensor = aten::type_as(%cond_mask.2, %cond_data.2)
+ %data.2 : Tensor = aten::mul(%cond_data.2, %36)
+ %38 : int = aten::dim(%data.2)
+ %39 : bool = aten::eq(%38, %35)
+ %cond_data : Tensor, %cond_mask : Tensor = prim::If(%39)
block0() {
- %47 : int = aten::dim(%data)
- %48 : int = aten::sub(%47, %40)
- %data.4 : Tensor = prim::Loop(%48, %39, %data.2)
- block0(%50 : int, %51 : Tensor) {
- %52 : int = aten::dim(%51)
- %data.3 : Tensor = aten::unsqueeze(%51, %52)
- -> (%39, %data.3)
+ %42 : int = aten::dim(%data)
+ %43 : int = aten::sub(%42, %35)
+ %data.4 : Tensor = prim::Loop(%43, %34, %data.2)
+ block0(%45 : int, %46 : Tensor) {
+ %47 : int = aten::dim(%46)
+ %data.3 : Tensor = aten::unsqueeze(%46, %47)
+ -> (%34, %data.3)
}
%cond_data.1 : Tensor = aten::expand_as(%data.4, %data)
%cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask)
%res_data : Tensor = aten::where(%cond_data, %data, %6_data)
%res_mask : Tensor = aten::where(%cond_mask, %mask, %6_mask)
%res_dims : Tensor = aten::__or__(%dims, %6_dims)
- %59 : int = prim::Constant[value=0]()
- %60 : Tensor = aten::mul(%35, %36)
- %61 : Tensor = aten::sum(%60)
- %62 : Tensor = aten::gt(%61, %59)
- %63 : bool = prim::Bool(%62)
- -> (%63, %35, %36, %37, %res_data, %res_mask, %res_dims)
+ %54 : int = prim::Constant[value=0]()
+ %55 : Tensor = aten::mul(%32, %33)
+ %56 : Tensor = aten::sum(%55)
+ %57 : Tensor = aten::gt(%56, %54)
+ %58 : bool = prim::Bool(%57)
+ -> (%58, %32, %33, %res_data, %res_mask, %res_dims)
}
- return (%a, %21, %22);
+ %59 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%a, %19, %20)
+ return (%59);
}
%b_ih : Float(*)
%b_hh : Float(*)) {
%hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih)
- return (%hy, %cy);
+ %9 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
+ return (%9);
}
with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
%1 : Float(*)
%12 : Tensor[] = aten::broadcast_tensors(%11)
%13 : Tensor, %14 : Tensor, %15 : Tensor, %16 : Tensor = prim::ListUnpack(%12)
%17 : Float(*, *), %cy : Float(*, *) = prim::FusionGroup_0(%cx, %16, %15, %14, %13)
- return (%17, %cy);
+ %19 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%17, %cy)
+ return (%19);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Tensor
%beta_h : Float(*)
%bias : Float(*)) {
%hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %bias, %beta_h, %beta_i, %alpha, %hx, %w_hh, %x, %w_ih)
- return (%hy, %cy);
+ %11 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
+ return (%11);
}
with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
%1 : Float(*)
def graph(self,
- a: Tensor) -> Tuple[]:
+ a: Tensor) -> None:
print("hi\016")
+ return None
while _0:
y_2 = torch.add_(y, 1, 1)
_0, y, z = bool(torch.lt(y_2, 8)), y_2, x
- return x, z
+ return (x, z)
def graph(self,
- y: Tensor) -> Tuple[]:
+ y: Tensor) -> None:
print("hi\016")
+ return None
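For context, a hypothetical script function of roughly this shape would print as shown above: with this change a function that returns nothing is annotated `-> None` and ends in `return None` instead of returning an empty tuple.

    import torch

    @torch.jit.script
    def no_return(y):
        # no return statement: the declared return type is now None, not Tuple[]
        print("hi")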
%b : Tensor) {
%2 : int = prim::Constant[value=1]()
%3 : Tensor = aten::add(%a, %b, %2)
- return (%3, %3);
+ %4 : (Tensor, Tensor) = prim::TupleConstruct(%3, %3)
+ return (%4);
}
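As a sketch of what these expect-file updates reflect (hypothetical source, assumed to correspond to the graph above): a script function that returns several values now produces one prim::TupleConstruct output instead of multiple graph outputs.

    import torch

    @torch.jit.script
    def pair(a, b):
        c = a + b
        return c, c  # packed into a single TupleConstruct output

    print(pair.graph)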
%4 : Double(2, 2) = aten::mul(%x, %3)
%5 : Double(2, 2) = aten::mul(%x, %2)
%6 : (Double(2, 2), Double(2, 2)) = prim::TupleConstruct(%4, %5)
- return (%x, %6);
+ %7 : (Double(2, 2), (Double(2, 2), Double(2, 2))) = prim::TupleConstruct(%x, %6)
+ return (%7);
}
%a.3 : Tensor = aten::sub_(%a.2, %b, %2)
%a.4 : Tensor = aten::div_(%a.3, %b)
%a : Tensor = aten::mul_(%a.4, %b)
- return (%a, %b);
+ %7 : (Tensor, Tensor) = prim::TupleConstruct(%a, %b)
+ return (%7);
}
= prim::Print(%d, %e, %d, %5, %y.4, %5)
-> (%c.1, %y.4)
}
- return (%a, %3, %c, %5, %y);
+ %19 : (int, int, int, Tensor, Tensor) = prim::TupleConstruct(%a, %3, %c, %5, %y)
+ return (%19);
}
block1() {
-> (%x.1, %x.1, %y.1)
}
- return (%x, %y, %z);
+ %7 : (Float(*, *), Tensor, Tensor) = prim::TupleConstruct(%x, %y, %z)
+ return (%7);
}
%7 : int[] = prim::ListConstruct(%5, %6)
%8 : Tensor = aten::rand(%7, %4, %3, %2)
%a : Tensor = aten::add_(%a.1, %8, %1)
- return ();
+ %10 : None = prim::None()
+ return (%10);
}
(assert
(eq (const 1) (const 1))
(option (string_literal hello)))
- (return
- (list (variable (ident x))))))
+ (return (variable (ident x)))))
}
%8 : int = prim::TupleIndex[index=0](%b)
%9 : int = prim::TupleIndex[index=1](%b)
- return (%8, %9);
+ %10 : (int, int) = prim::TupleConstruct(%8, %9)
+ return (%10);
}
}
%c : (int, int, int, int) = prim::TupleSlice[beg=0, end=4](%b)
%e : (int, int) = prim::TupleSlice[beg=1, end=3](%c)
- %11 : int, %12 : int = prim::TupleUnpack(%e)
- return (%11, %12);
+ return (%e);
}
def get_grad_executor(plan_state, diff_graph_idx=None):
- if diff_graph_idx is None and len(list(plan_state.graph.nodes())) != 1:
- raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")
+ if diff_graph_idx is None:
+ nodes = list(plan_state.graph.nodes())
+ if len(nodes) == 1 or (len(nodes) == 2 and nodes[1].kind() == "prim::TupleConstruct"):
+ pass
+ else:
+ raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")
grad_executors = list(plan_state.code.grad_executors())
return grad_executors[diff_graph_idx or 0]
ppv = "op_version_set = 0\n{}".format(pp)
sm = copy_structure_and_params(module)
torch._C._jit_import_methods(sm, ppv, constant_table)
-
pp2, _ = sm._python_print()
if pp != pp2:
self.assertMultiLineEqual(pp, pp2)
with self.assertRaisesRegex(pickle.PickleError, "not supported"):
torch.save(FooToPickle(), "will_fail")
+ def test_single_tuple_trace(self):
+ x = torch.tensor(2.)
+
+ def f2(x):
+ return (x,)
+ jit_f2 = torch.jit.trace(f2, x)
+ assert f2(x) == jit_f2(x) # fails
+
@unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
def test_restore_device_cuda(self):
class MyModule(torch.jit.ScriptModule):
self.checkScript(one_return, [a], optimize=True)
self.checkScript(multiple_returns, [a], optimize=True)
- with self.assertRaisesRegex(RuntimeError, "Expected 1 return value"):
+ with self.assertRaisesRegex(RuntimeError, "but is actually of type None"):
@torch.jit.script
def no_return_bad_annotation(a):
# type: (Tensor) -> Tensor
self.checkScript(tuple_index, (torch.tensor([1]),), optimize=True)
tuple_comp = torch.jit.script(tuple_index)
self.assertExpectedGraph(tuple_comp.graph)
- self.run_pass('lower_all_tuples', tuple_comp.graph)
- m = torch.jit.ScriptModule()
- m._create_method_from_graph("forward", tuple_comp.graph)
- self.assertEqual(m(torch.tensor(1)), (1, 2))
+ self.assertEqual(tuple_comp(torch.tensor(1)), (1, 2))
with self.assertRaisesRegex(RuntimeError, "tuple indices must be integer constants"):
@torch.jit.script
return e
self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True)
+ tuple_graph = torch.jit.script(tuple_slice)
+ self.assertExpectedGraph(tuple_graph.graph)
+ self.run_pass('lower_all_tuples', tuple_graph.graph)
+ self.assertTrue('Tuple' not in str(tuple_graph.graph))
tuple_comp = torch.jit.script(tuple_slice)
- self.assertExpectedGraph(tuple_comp.graph)
- self.run_pass('lower_all_tuples', tuple_comp.graph)
- self.assertTrue('Tuple' not in str(tuple_comp.graph))
- m = torch.jit.ScriptModule()
- m._create_method_from_graph("forward", tuple_comp.graph)
- self.assertEqual(m(torch.tensor(1)), (2, 3))
+ self.assertEqual(tuple_comp(torch.tensor(1)), (2, 3))
@torch.jit.script
def test_indexing_end_out_of_bounds():
c = (1, 2)
return c[2:10]
- # output is None in script and () in python
- self.assertEqual(test_indexing_end_out_of_bounds(), None)
+ self.assertEqual(test_indexing_end_out_of_bounds(), ())
def test_unwrap_optional_builtin(self):
def test(x):
self.assertExpected(sm.__getattr__('forward').pretty_print_schema())
def test_annotated_script_fn_return_mismatch(self):
- with self.assertRaisesRegex(RuntimeError, r"Return value at position 0 was annotated as "
- r"having type \(Tensor, Tensor\) but is "
- r"actually of type Tensor"):
+ with self.assertRaisesRegex(RuntimeError, "but is actually of type"):
@torch.jit.script
def return_tup(x):
# type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
export_type=torch.onnx.ExportTypes.DIRECTORY)
shutil.rmtree(d)
+ def test_onnx_multiple_return(self):
+ @torch.jit.script
+ def foo(a):
+ return (a, a)
+ f = io.BytesIO()
+ x = torch.ones(3)
+ torch.onnx._export(foo, (x,), f, example_outputs=(x, x))
+
@skipIfRocm
@skipIfNoLapack
def test_aten_fallback(self):
def check_output_types(self, func, ref_outputs, args, kwargs):
graph = getattr(func, 'last_graph', None)
- if not isinstance(ref_outputs, tuple):
- ref_outputs = (ref_outputs,)
types = [o.type() for o in graph.outputs()]
- self.assertEqual(len(types), len(ref_outputs))
- for i, (t, ref_out) in enumerate(zip(types, ref_outputs)):
- if isinstance(ref_out, list):
- assert len(ref_out) > 0
- elem = ref_out[0]
- assert isinstance(elem, torch.Tensor)
- self.assertTrue(t.isSubtypeOf(torch._C.ListType.ofTensors()))
- else:
- ref_type = torch._C.Type.inferFrom(ref_out)
- self.assertTrue(ref_type.isSubtypeOf(t))
+ self.assertTrue(len(types) == 1)
+ t = types[0]
+ torch._C._jit_assert_is_instance(ref_outputs, t)
def check_against_reference(self, func, reference_func, args, kwargs=None,
expected1, expected2 = f(x, y)
self.assertEqual(result1, expected1)
self.assertEqual(result2, expected2)
- self.assertAllFused(script_f.graph_for(x, y))
+ self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
# TODO: This test seems dead
@unittest.skipIf(not IS_WINDOWS, "Testing Fuse skipped on windows")
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
- self.assertGraphSize(graph, 2)
+ self.assertGraphSize(graph, 3)
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
def test_merges_without_cycles(self):
graph = self._perform_ad_subgraph_slicing(fn, 2, 2)
- self.assertGraphSize(graph, 1)
+ self.assertGraphSize(graph, 2)
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
def test_does_not_create_cycles(self):
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
- self.assertGraphSize(graph, 2)
+ self.assertGraphSize(graph, 3)
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
def test_merges_down(self):
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
- self.assertGraphSize(graph, 2)
+ self.assertGraphSize(graph, 3)
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
def test_respects_lexical_scoping(self):
// TODO: this is a fake stub
});
+ m.def("_jit_assert_is_instance", [](py::object obj, TypePtr type) {
+ toIValue(obj, type);
+ });
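A hedged usage sketch of the new binding, mirroring how check_output_types uses it above (assuming it is exposed as torch._C._jit_assert_is_instance): toIValue raises if the Python object is not an instance of the given JIT type.

    import torch

    @torch.jit.script
    def identity(x):
        return x

    out_type = list(identity.graph.outputs())[0].type()
    torch._C._jit_assert_is_instance(torch.ones(2), out_type)     # ok
    # torch._C._jit_assert_is_instance("not a tensor", out_type)  # would raise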
initPythonIRBindings(module);
tracer::initPythonTracerBindings(module);
return;
switch (node->kind()) {
case prim::Return:
+ if (enforce_importable_ && node->inputs().size() != 1) {
+ throw script::ErrorReport(node->getSourceLocation())
+ << "Exportable methods must have a single return value. Normal use of ScriptMethods should enforce this.";
+ }
if (node->inputs().size() > 0) {
indent();
out << "return ";
return out;
}
- void printDefaultValue(std::ostream& stmt, const IValue& value) {
+ void printDefaultValue(const TypePtr& typ, std::ostream& stmt, const IValue& value) {
+ // XXX - many weak script modules store default values for broadcasting lists
+ // that are not actually the same type as the argument. We can only serialize
+ // default values that will implicitly convert to their declared argument type.
+ // Since we do not need to serialize these built-in modules with their defaults,
+ // we just drop them for now.
+ if (typ->kind() == ListType::Kind &&
+ (value.isInt() || value.isDouble() || value.isBool())) {
+ return;
+ }
+ stmt << "=";
if (value.isTensor() && !value.toTensor().defined()) {
// XXX - because undefined tensors are not stored as None, we need special handling.
// otherwise they get printed as CONSTANTS.c0 and then cannot be recreated because
if (defaults_offset != defaults.end()) {
const c10::optional<IValue>& def = *defaults_offset++;
if (def) {
- out << "=";
- printDefaultValue(out, *def);
+ printDefaultValue(input->type(), out, *def);
}
}
}
#include <torch/csrc/jit/passes/to_batch.h>
#include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
namespace torch { namespace jit {
throw std::runtime_error("function " + name + " with " + std::to_string(num_inputs) + " inputs is not supported in batched tensor yet");
}
+std::vector<Value*> inlineUnpackedCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
+ auto outputs = script::inlineCallTo(g, callee, inputs);
+ if (callee.outputs().size() == 1 && callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
+ auto tc = script::createTupleUnpack(outputs.at(0));
+ outputs = std::vector<Value*>(tc.begin(), tc.end());
+ }
+ return outputs;
+}
+
// replace aten operator node with BatchTensor operator graph
void ToBatch::visitAten(Node* n, Block* block, Block* res_block){
auto res_graph = res_block->owningGraph();
}
auto batch_graph = getBatchOperator(func_name, new_inputs.size());
- auto outputs = script::inlineCallTo(*res_block->owningGraph(), *batch_graph, new_inputs);
+ auto outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *batch_graph, new_inputs);
// Assume all outputs from inlined operator implementation are in the triple form batched tensor or just a single non-tensor.
if (outputs.size() == 1) {
// if previous output is scalar, transform new output back to scalar from dynamic
TypePtr orig_type = n->outputs()[0]->type();
- if (orig_type != outputs[0]->type()){
+ if (!orig_type->isSubtypeOf(outputs[0]->type())) {
Symbol op;
if (orig_type == IntType::get()) {
op = prim::Int;
auto res_graph = res_block->owningGraph();
auto* r_node = res_graph->createClone(n, rn_fn);
res_block->appendNode(r_node);
- auto outputs = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("batch_from_scalar_tensor"), r_node->outputs());
+ auto outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("batch_from_scalar_tensor"), r_node->outputs());
batch_map[n->output()] = outputs;
}
inputs.insert(inputs.end(), if_output.begin(), if_output.end());
auto else_output = batch_map.at(n->blocks()[1]->outputs()[i]);
inputs.insert(inputs.end(), else_output.begin(), else_output.end());
- auto outputs = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("where", inputs.size()), inputs);
+ auto outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("where", inputs.size()), inputs);
batch_map[n->outputs()[i]] = outputs;
}
}
}
if(cond_is_tensor){
auto cond = batch_map.at(n->inputs()[1]);
- auto cond_any = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
+ auto cond_any = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
rn_env[n->inputs()[1]] =res_graph->insert(prim::Bool, {cond_any[0]});
}
for(size_t i = 2; i < n->inputs().size(); i++){
for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
inputs.push_back(loop_block->inputs()[i * EXP_BTENSOR_SIZE + j + EXP_BTENSOR_SIZE + 1]);
}
- outputs = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("where"), inputs);
+ outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("where"), inputs);
}
else{
for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
}
auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]);
inputs.insert(inputs.end(), data.begin(), data.end());
- outputs = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("update"), inputs);
+ outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("update"), inputs);
}
batch_map[n->outputs()[i]] = outputs;
for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
// update loop conditions
if(cond_is_tensor){
auto cond = batch_map.at(n->blocks()[0]->outputs()[0]);
- auto cond_any = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
+ auto cond_any = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
auto to_bool_output = res_graph->insert(prim::Bool, {cond_any[0]});
loop_block->insertOutput(0, to_bool_output);
for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
}
}
-std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph>& graph){
+std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph) {
+ // lower the tuple before the pass
+ if (graph->outputs().at(0)->type()->kind() == TupleType::Kind) {
+ graph = graph->copy();
+ auto outs = script::createTupleUnpack(graph->outputs().at(0));
+ graph->eraseOutput(0);
+ for(auto o : outs)
+ graph->registerOutput(o);
+ EliminateDeadCode(graph->block());
+ }
std::shared_ptr<Graph> res_graph = std::make_shared<Graph>();
ToBatch to_batch;
to_batch.toBatch(graph->block(), res_graph->block());
+ // methods should only have a single output, so we pack everything into a tuple
+ auto tup = res_graph->insertNode(res_graph->createTuple(res_graph->outputs()));
+ while (res_graph->outputs().size() > 0)
+ res_graph->eraseOutput(res_graph->outputs().size() - 1);
+ res_graph->registerOutput(tup->output());
+ EliminateDeadCode(res_graph->block());
+
return res_graph;
}
void initRegisterBatchOpsBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
- m.def("to_batch_graph", &to_batch_graph);
+ m.def("to_batch_graph", to_batch_graph);
m.def("register_batch_operator", [](std::string name, std::shared_ptr<Graph> graph){
ToBatch::batch_operator_table[name].push_back(graph);
});
TORCH_API void toBatch(Block* block, Block* res_block);
};
-TORCH_API std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph>& graph);
+TORCH_API std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph);
TORCH_API void initRegisterBatchOpsBindings(PyObject* module);
}}
AutoNoGIL no_gil_guard;
method.run(stack);
}
- return createPyObjectForStack(std::move(stack));
+ return toPyObject(std::move(stack.back()));
}
inline py::object invokeOperatorFromPython(
})
.def("isSubtypeOf", [](std::shared_ptr<Type>& self, std::shared_ptr<Type> other) {
return self->isSubtypeOf(other);
- })
- .def_static("inferFrom", c10::inferTypeFrom);
+ });
py::class_<NumberType, Type, std::shared_ptr<NumberType>>(m, "NumberType")
.def_static("get", &NumberType::get);
AT_ERROR("The traced function didn't return any values! Side-effects are not "
"captured in traces, so it would be a no-op.");
}
- if (!PyTuple_Check(out.ptr())) {
- out = py::make_tuple(out);
- }
- tracer::exit(toStack(out));
+ tracer::exit({toIValue(out)});
auto graph = enter_info.first->graph;
EliminateDeadCode(graph);
LowerSimpleTuples(graph);
const TypePtr& concrete_type,
Value* value,
bool allow_conversions) {
- // Allow homogeneous tuples to be casted implicitly to lists of appropriate
- // types
- if (convertibleToList(value->type(), unwrapOptional(concrete_type)) &&
- value->type()->kind() == TypeKind::TupleType) {
- auto unpacked = createTupleUnpack(value);
- auto elem_type = unwrapOptional(concrete_type)->expect<ListType>()->getElementType();
- value = graph.insertNode(graph.createList(elem_type, unpacked))->output();
+
+ if (auto value_tuple = value->type()->cast<TupleType>()) {
+ // Allow homogeneous tuples to be casted implicitly to lists of appropriate
+ // types
+ if (convertibleToList(value->type(), unwrapOptional(concrete_type))) {
+ auto unpacked = createTupleUnpack(value);
+ auto elem_type = unwrapOptional(concrete_type)->expect<ListType>()->getElementType();
+ value = graph.insertNode(graph.createList(elem_type, unpacked))->output();
+ }
+ // inductively apply implicit conversions to tuples
+ if (auto concrete_tuple = concrete_type->cast<TupleType>()) {
+ if (!value_tuple->isSubtypeOf(concrete_tuple) &&
+ concrete_tuple->elements().size() == value_tuple->elements().size()) {
+ auto unpacked = createTupleUnpack(value);
+ std::vector<Value*> converted;
+ for (size_t i = 0; i < concrete_tuple->elements().size(); ++i) {
+ converted.emplace_back(tryConvertToType(
+ loc,
+ graph,
+ concrete_tuple->elements().at(i),
+ unpacked.at(i),
+ allow_conversions));
+ }
+ value = graph.insertNode(graph.createTuple(converted))->output();
+ }
+ }
}
if (value->type()->isSubtypeOf(NoneType::get()) && !concrete_type->isSubtypeOf(NoneType::get())){
if (self && def.decl().params().size() == 0) {
throw ErrorReport(def.decl().params().range()) << "methods must have a self argument";
}
-
auto schema = extractSchemaFromDef(def);
std::vector<Argument> arguments = emitFormalArguments(self, schema);
+
// body
auto stmts = def.statements();
auto stmts_begin = stmts.begin();
return_stmt = Return(*stmts_end);
}
emitStatements(stmts_begin, stmts_end);
- std::vector<Argument> returns = emitReturn(return_stmt, schema);
+ std::vector<Argument> returns = {emitReturn(
+ return_stmt ? return_stmt->range() : def.range(), return_stmt, schema)};
method.setSchema({def.name().name(), std::move(arguments), std::move(returns)});
// remove any uses of tuples that we inserted that are not needed
Decl::create(r, List<Param>::create(r, {}), Maybe<Expr>::create(r, tuple_type));
auto tuple_expr = TupleLiteral::create(r, List<Expr>::create(r, default_exprs));
- auto ret = Return::create(r, List<Expr>::create(r, { tuple_expr }));
+ auto ret = Return::create(r, tuple_expr);
auto def = Def::create(
r,
Ident::create(r, "defaults"),
List<Stmt>::create(r, {ret}));
auto m = std::make_shared<Module>();
defineMethodsInModule(m, {def}, {resolver}, nullptr);
- m->get_method("defaults").run(default_values);
- return default_values;
+ Stack stack;
+ m->get_method("defaults").run(stack);
+ return stack.at(0).toTuple()->elements();
}
std::vector<Argument> parseArgsFromDecl(const Decl& decl) {
return retval;
}
- std::vector<Argument> parseReturnsFromDecl(const Decl& decl) {
- JIT_ASSERT(decl.return_type().present());
+ std::vector<Argument> parseReturnFromDecl(const Decl& decl) {
+ // We represent no annotation on a return type as having no values in the
+ // schema's returns() list.
+ // In emitReturn we take the actual return value to be the value of the return
+ // statement if none was provided here.
+ if(!decl.return_type().present())
+ return {};
+
if (handleBroadcastList(decl.return_type().get()))
throw ErrorReport(decl.return_type().range()) << "Broadcastable lists cannot appear as a return type";
auto parsed_type = parseTypeFromExpr(decl.return_type().get());
- if (auto tuple_type = parsed_type->cast<TupleType>()) {
- // Flatten a single return type of type Tuple into its constituent types
- std::vector<Argument> retval;
- for (auto type_ptr : tuple_type->elements()) {
- retval.emplace_back(
- "",
- type_ptr,
- /*N =*/c10::nullopt,
- /*default_value =*/c10::nullopt,
- /*kwarg_only =*/false);
- }
- return retval;
- } else {
- return {Argument(
- "",
- parsed_type,
- /*N =*/c10::nullopt,
- /*default_value =*/c10::nullopt,
- /*kwarg_only =*/false)};
- }
+ return {Argument(
+ "",
+ parsed_type,
+ /*N =*/c10::nullopt,
+ /*default_value =*/c10::nullopt,
+ /*kwarg_only =*/false)};
}
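A brief illustration of the schema change (hypothetical example): a Tuple[...] return annotation is no longer flattened into multiple schema returns; it stays one return of tuple type, matching the single TupleConstruct output in the graph.

    import torch

    @torch.jit.script
    def two(x):
        # type: (Tensor) -> Tuple[Tensor, Tensor]
        return x, x + 1

    print(two.graph)  # one output, produced by prim::TupleConstruct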
FunctionSchema extractSchemaFromDef(const Def &def) {
auto name = def.name().name();
std::vector<Argument> args = parseArgsFromDecl(def.decl());
- std::vector<Argument> returns;
- bool is_varret;
- if (def.decl().return_type().present()) {
- returns = parseReturnsFromDecl(def.decl());
- is_varret = false;
- } else {
- is_varret = true;
- }
- return FunctionSchema(name, args, returns, false, is_varret);
+ std::vector<Argument> returns = parseReturnFromDecl(def.decl());
+ return FunctionSchema(name, std::move(args), std::move(returns), false, false);
}
std::vector<Argument> emitFormalArguments(const SugaredValuePtr& self, const FunctionSchema& schema) {
}
return arguments;
}
- std::vector<Argument> emitReturn(c10::optional<Return> return_stmt_, const FunctionSchema& schema) {
+
+ Argument emitReturn(const SourceRange& range, c10::optional<Return> return_stmt, const FunctionSchema& schema) {
+ JIT_ASSERT(schema.returns().size() <= 1);
// outputs
- std::vector<Argument> returns;
- if (return_stmt_) {
- auto return_stmt = *return_stmt_;
- auto results = getValues(return_stmt.values(), true);
- // a single return value that is a tuple expands in place:
- // return a
- if (return_stmt.values().size() == 1 && results.size() == 1) {
- auto result = results.at(0);
- if(result->type()->cast<TupleType>()) {
- results = createTupleUnpack(result).vec();
- }
- }
- if (!schema.is_varret() && schema.returns().size() != results.size()) {
- throw ErrorReport(def.range()) << "Number of type annotations for function"
- << " return (" << schema.returns().size() << ") does not match"
- << " the number of returns from the function (" << results.size() << ")!";
- }
- const auto& range = return_stmt.range();
- size_t return_type_idx = 0;
- for (auto r : results) {
- TypePtr type = DynamicType::get();
- if (!schema.is_varret()) {
- type = schema.returns().at(return_type_idx).type();
- r = tryConvertToType(range, *graph, type, r, /*allow_conversions=*/false);
- if (!r->type()->isSubtypeOf(type)) {
- throw ErrorReport(return_stmt.range()) << "Return value at position "
- << return_type_idx << " was annotated as having type " << type->str()
- << " but is actually of type " << r->type()->str();
- }
- return_type_idx++;
- }
- graph->registerOutput(r);
- returns.emplace_back("", type);
- }
- } else if (schema.returns().size() > 0) {
- // schema has returns but there's no return nodes in graph
- throw ErrorReport() << "Expected " << schema.returns().size()
- << " return value"
- << (schema.returns().size() > 1 ? "s" : "")
- << " but found no return statement";
+ Value* result = return_stmt ? emitExpr(return_stmt->expr())
+ : graph->insertConstant(IValue(), range);
+ TypePtr result_type = schema.returns().size() > 0
+ ? schema.returns().at(0).type()
+ : result->type();
+
+ if (return_stmt) {
+ result = tryConvertToType(
+ range, *graph, result_type, result, /*allow_conversions=*/true);
+ }
+
+ if (!result->type()->isSubtypeOf(result_type)) {
+ throw ErrorReport(range) << "Return value was annotated as having type " << result_type->python_str()
+ << " but is actually of type " << result->type()->python_str();
}
- return returns;
+ graph->registerOutput(result);
+ return Argument("", result_type);
}
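To make the new single-return check concrete, a hedged sketch (hypothetical function) of the error path: annotating one return type while the body yields another now reports the actual type of the single return value.

    import torch

    try:
        @torch.jit.script
        def bad(x):
            # type: (Tensor) -> Tensor
            return x, x  # annotated Tensor, actually Tuple[Tensor, Tensor]
    except RuntimeError as e:
        assert "but is actually of type" in str(e)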
void emitStatements(const List<Stmt>& statements) {
return emitStatements(statements.begin(), statements.end());
// technically this is not a python type but we need it when
// parsing serialized methods that use implicit converions to Scalar
{"number", NumberType::get()},
+ {"None", NoneType::get()},
};
return map;
}
case TK_VAR: {
return Var(expr).name().name();
}
+ case TK_NONE: {
+ return "None";
+ }
case '.': {
auto select = Select(expr);
const std::string& name = select.selector().name();
if (isTorch(select.value()) && name == "Tensor")
return "Tensor";
- }
+ } break;
}
return at::nullopt;
}
const std::string& name,
at::ArrayRef<NamedValue> kwargs);
+TORCH_API at::ArrayRef<Value*> createTupleUnpack(Value* v);
+
} // namespace script
} // namespace jit
} // namespace torch
const std::string& script,
ResolutionCallback rcb, bool has_self) {
auto self = has_self ? std::make_shared<ModuleValue>(m) : nullptr;
- return defineMethodsInModule(m, script, pythonResolver(rcb), self);
+ defineMethodsInModule(m, script, pythonResolver(rcb), self);
})
.def("_create_methods", [](std::shared_ptr<Module> m,
const std::vector<Def>& defs,
if (self.find_method("forward")) {
Method & m = self.get_method("forward");
return m.graph_for(
- createStackForSchema(m.getSchema(), tuple_slice(std::move(args), 1), std::move(kwargs)));
+ createStackForSchema(m.getSchema(), tuple_slice(std::move(args), 1), kwargs));
}
throw std::runtime_error("Attempted to call graph_for on a Module without a compiled forward()");
})
.def("graph_for", [](py::args args, py::kwargs kwargs) {
// see: [pybind11 varargs]
Method& self = py::cast<Method&>(args[0]);
- return self.graph_for(createStackForSchema(self.getSchema(), tuple_slice(std::move(args), 1), std::move(kwargs)));
+ return self.graph_for(createStackForSchema(self.getSchema(), tuple_slice(std::move(args), 1), kwargs));
})
.def("debug_disable_autodiff_subgraph_inlining", &Method::debugDisableAutodiffSubgraphInlining)
.def("schema", &Method::getSchema)
IValue operator()(std::vector<IValue> stack) {
checkInputsAgainstSchema(stack);
run(stack);
- if (stack.size() != 1) {
- return Tuple::create(std::move(stack));
- }
return stack.front();
}
auto type = torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
retval->inputs()[i]->setType(type);
}
- JIT_ASSERT(retval->outputs().size() == outputs.size());
+ at::ArrayRef<Value*> output_values = retval->outputs();
+ // patch this to still work if we are returning a tuple of multiple values
+ if (output_values.at(0)->type()->kind() == TupleType::Kind) {
+ JIT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct);
+ output_values = output_values.at(0)->node()->inputs();
+ }
+ JIT_ASSERT(output_values.size() == outputs.size());
for (size_t i=0; i < retval->outputs().size(); ++i) {
auto scalar_type = outputs[i].type().scalarType();
auto sizes = outputs[i].sizes();
auto type = torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
- retval->outputs()[i]->setType(type);
+ output_values[i]->setType(type);
}
return retval;
}
}
GraphExecutor& get_executor() {
- std::call_once(executor_init, [&]{
+ std::call_once(executor_init, [&] {
+ AT_CHECK(
+ graph()->outputs().size() == 1,
+ "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
executor = GraphExecutor(graph(), optimize);
});
return executor;
for (size_t pos = 0; pos < schema.arguments().size(); ++pos) {
const auto& argument = schema.arguments()[pos];
if (pos < inputs.size()) {
- const TypePtr inputType = inferTypeFrom(inputs[pos]);
+ // XXX - this fails to handle generic aggregates
+ // and should be replaced with a function isSubvalueOf(ivalue, type)
+ // that asks whether the specific value is a valid instance of the type.
+ const TypePtr inputType = incompleteInferTypeFrom(inputs[pos]);
AT_CHECK(inputType->isSubtypeOf(argument.type()),
"Expected value of type ", *argument.type(),
" for argument '", argument.name(),
}
case TK_RETURN: {
auto range = L.next().range;
- // XXX: TK_NEWLINE makes it accept an empty list
- auto values = parseList(TK_NOTHING, ',', TK_NEWLINE, &ParserImpl::parseExp);
- return Return::create(range, values);
+ Expr value = L.cur().kind != TK_NEWLINE ? parseExpOrExpTuple()
+ : Expr(c(TK_NONE, range, {}));
+ L.expect(TK_NEWLINE);
+ return Return::create(range, value);
}
case TK_RAISE: {
auto range = L.next().range;
} while(!L.nextIf(TK_DEDENT));
return c(TK_LIST, r, std::move(stmts));
}
- Decl parseDecl() {
- auto paramlist = parseList('(', ',', ')', &ParserImpl::parseParam);
- // Parse return type annotation
- TreeRef return_type;
+
+ Maybe<Expr> parseReturnAnnotation() {
if (L.nextIf(TK_ARROW)) {
// Exactly one expression for return type annotation
auto return_type_range = L.cur().range;
- return_type = Maybe<Expr>::create(return_type_range, parseExp());
+ return Maybe<Expr>::create(return_type_range, parseExp());
} else {
- // Default to returning single tensor. TODO: better sentinel value?
- return_type = Maybe<Expr>::create(L.cur().range);
+ return Maybe<Expr>::create(L.cur().range);
}
+ }
+
+ Decl parseDecl() {
+ auto paramlist = parseList('(', ',', ')', &ParserImpl::parseParam);
+ // Parse return type annotation
+ TreeRef return_type;
+ Maybe<Expr> return_annotation = parseReturnAnnotation();
L.expect(':');
- return Decl::create(paramlist.range(), List<Param>(paramlist), Maybe<Expr>(return_type));
+ return Decl::create(paramlist.range(), List<Param>(paramlist), return_annotation);
}
TreeRef parseFunction(bool is_method) {
return AugAssign::create(r, lhs, kind, rhs);
}));
py::class_<Return, Stmt>(m, "Return")
- .def(py::init([](const SourceRange& range, std::vector<Expr> values) {
- return Return::create(range, wrap_list(range, std::move(values)));
+ .def(py::init([](const SourceRange& range, Expr* value) {
+ return Return::create(range, value ? *value : Expr(Compound::create(TK_NONE, range, {})));
}));
py::class_<Raise, Stmt>(m, "Raise")
.def(py::init([](const SourceRange& range, Expr *expr) {
explicit Return(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_RETURN);
}
- List<Expr> values() const {
- return List<Expr>(subtree(0));
+ Expr expr() const {
+ return Expr(subtree(0));
}
- static Return create(const SourceRange& range, const List<Expr>& values) {
- return Return(Compound::create(TK_RETURN, range, {values}));
+ static Return create(const SourceRange& range, const Expr& value) {
+ return Return(Compound::create(TK_RETURN, range, {value}));
}
};
}
};
for (IValue& input : inputs) {
- input = add_input(input, inferTypeFrom(input), state->graph->addInput());
+ input = add_input(input, incompleteInferTypeFrom(input), state->graph->addInput());
}
return std::make_pair(state, inputs);
}
@staticmethod
def build_Return(ctx, stmt):
r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("return"))
- values = (stmt.value,) if not isinstance(stmt.value, ast.Tuple) else stmt.value.elts
- return Return(r, [build_expr(ctx, val) for val in values if val is not None])
+ return Return(r, None if stmt.value is None else build_expr(ctx, stmt.value))
@staticmethod
def build_Raise(ctx, stmt):