+++ /dev/null
-graph(%x.1_data : Tensor,
- %x.1_mask : Tensor,
- %x.1_dims : Tensor,
- %y_data : Tensor,
- %y_mask : Tensor,
- %y_dims : Tensor):
- %6 : int = prim::Constant[value=1]()
- %7 : bool = prim::Constant[value=1]()
- %8 : int = prim::Constant[value=10]()
- %x : Tensor, %10 : Tensor, %11 : Tensor = prim::Loop(%8, %7, %x.1_data, %x.1_mask, %x.1_dims)
- block0(%loop_num : int, %5_data : Tensor, %5_mask : Tensor, %5_dims : Tensor):
- %16 : Long() = prim::NumToTensor(%6)
- %alpha : float = prim::Float(%16)
- %data.1 : Tensor = aten::add(%5_data, %y_data, %alpha)
- %mask : Tensor = aten::mul(%5_mask, %y_mask)
- %dims : Tensor = aten::__or__(%5_dims, %y_dims)
- %data : Tensor = aten::where(%mask, %data.1, %5_data)
- -> (%7, %data, %mask, %dims)
- %22 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%x, %10, %11)
- return (%22)
+++ /dev/null
-graph(%a.1_data : Tensor,
- %a.1_mask : Tensor,
- %a.1_dims : Tensor,
- %b_data : Tensor,
- %b_mask : Tensor,
- %b_dims : Tensor):
- %6 : int = prim::Constant[value=1]()
- %7 : Tensor = aten::gt(%a.1_data, %b_data)
- %8 : Tensor = aten::mul(%a.1_mask, %b_mask)
- %9 : Long() = prim::NumToTensor(%6)
- %alpha.1 : float = prim::Float(%9)
- %data.1 : Tensor = aten::add(%a.1_data, %b_data, %alpha.1)
- %mask.1 : Tensor = aten::mul(%a.1_mask, %b_mask)
- %dims.1 : Tensor = aten::__or__(%a.1_dims, %b_dims)
- %14 : Long() = prim::NumToTensor(%6)
- %alpha : float = prim::Float(%14)
- %data : Tensor = aten::sub(%a.1_data, %b_data, %alpha)
- %mask : Tensor = aten::mul(%a.1_mask, %b_mask)
- %dims : Tensor = aten::__or__(%a.1_dims, %b_dims)
- %19 : bool = prim::Constant[value=1]()
- %20 : int = prim::Constant[value=1]()
- %21 : Tensor = aten::type_as(%8, %7)
- %data.2 : Tensor = aten::mul(%7, %21)
- %23 : int = aten::dim(%data.2)
- %24 : bool = aten::eq(%23, %20)
- %cond_data : Tensor, %cond_mask : Tensor = prim::If(%24)
- block0():
- %27 : int = aten::dim(%data.1)
- %28 : int = aten::sub(%27, %20)
- %data.4 : Tensor = prim::Loop(%28, %19, %data.2)
- block0(%30 : int, %31 : Tensor):
- %32 : int = aten::dim(%31)
- %data.3 : Tensor = aten::unsqueeze(%31, %32)
- -> (%19, %data.3)
- %cond_data.1 : Tensor = aten::expand_as(%data.4, %data.1)
- %cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask.1)
- -> (%cond_data.1, %cond_mask.1)
- block1():
- -> (%data.2, %data.2)
- %res_data : Tensor = aten::where(%cond_data, %data.1, %data)
- %res_mask : Tensor = aten::where(%cond_mask, %mask.1, %mask)
- %res_dims : Tensor = aten::__or__(%dims.1, %dims)
- %39 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims)
- return (%39)
+++ /dev/null
-graph(%a.1_data : Tensor,
- %a.1_mask : Tensor,
- %a.1_dims : Tensor,
- %b_data : Tensor,
- %b_mask : Tensor,
- %b_dims : Tensor):
- %6 : int = prim::Constant[value=1]()
- %7 : float = prim::Constant[value=0.1]()
- %8 : Float() = prim::NumToTensor(%7)
- %other : float = prim::Float(%8)
- %10 : Tensor = aten::gt(%a.1_data, %other)
- %11 : Long() = prim::NumToTensor(%6)
- %alpha.1 : float = prim::Float(%11)
- %data.1 : Tensor = aten::add(%a.1_data, %b_data, %alpha.1)
- %mask.1 : Tensor = aten::mul(%a.1_mask, %b_mask)
- %dims.1 : Tensor = aten::__or__(%a.1_dims, %b_dims)
- %16 : Long() = prim::NumToTensor(%6)
- %alpha : float = prim::Float(%16)
- %data : Tensor = aten::sub(%a.1_data, %b_data, %alpha)
- %mask : Tensor = aten::mul(%a.1_mask, %b_mask)
- %dims : Tensor = aten::__or__(%a.1_dims, %b_dims)
- %21 : bool = prim::Constant[value=1]()
- %22 : int = prim::Constant[value=1]()
- %23 : Tensor = aten::type_as(%a.1_mask, %10)
- %data.2 : Tensor = aten::mul(%10, %23)
- %25 : int = aten::dim(%data.2)
- %26 : bool = aten::eq(%25, %22)
- %cond_data : Tensor, %cond_mask : Tensor = prim::If(%26)
- block0():
- %29 : int = aten::dim(%data.1)
- %30 : int = aten::sub(%29, %22)
- %data.4 : Tensor = prim::Loop(%30, %21, %data.2)
- block0(%32 : int, %33 : Tensor):
- %34 : int = aten::dim(%33)
- %data.3 : Tensor = aten::unsqueeze(%33, %34)
- -> (%21, %data.3)
- %cond_data.1 : Tensor = aten::expand_as(%data.4, %data.1)
- %cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask.1)
- -> (%cond_data.1, %cond_mask.1)
- block1():
- -> (%data.2, %data.2)
- %res_data : Tensor = aten::where(%cond_data, %data.1, %data)
- %res_mask : Tensor = aten::where(%cond_mask, %mask.1, %mask)
- %res_dims : Tensor = aten::__or__(%dims.1, %dims)
- %41 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims)
- return (%41)
+++ /dev/null
-graph(%a.1_data : Tensor,
- %a.1_mask : Tensor,
- %a.1_dims : Tensor,
- %b_data : Tensor,
- %b_mask : Tensor,
- %b_dims : Tensor):
- %6 : int = prim::Constant[value=1]()
- %7 : Tensor = aten::gt(%a.1_data, %b_data)
- %8 : Tensor = aten::mul(%a.1_mask, %b_mask)
- %9 : Long() = prim::NumToTensor(%6)
- %alpha : float = prim::Float(%9)
- %data : Tensor = aten::add(%a.1_data, %b_data, %alpha)
- %mask : Tensor = aten::mul(%a.1_mask, %b_mask)
- %dims : Tensor = aten::__or__(%a.1_dims, %b_dims)
- %14 : bool = prim::Constant[value=1]()
- %15 : int = prim::Constant[value=1]()
- %16 : Tensor = aten::type_as(%8, %7)
- %data.2 : Tensor = aten::mul(%7, %16)
- %18 : int = aten::dim(%data.2)
- %19 : bool = aten::eq(%18, %15)
- %cond_data : Tensor, %cond_mask : Tensor = prim::If(%19)
- block0():
- %22 : int = aten::dim(%data)
- %23 : int = aten::sub(%22, %15)
- %data.4 : Tensor = prim::Loop(%23, %14, %data.2)
- block0(%25 : int, %26 : Tensor):
- %27 : int = aten::dim(%26)
- %data.3 : Tensor = aten::unsqueeze(%26, %27)
- -> (%14, %data.3)
- %cond_data.1 : Tensor = aten::expand_as(%data.4, %data)
- %cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask)
- -> (%cond_data.1, %cond_mask.1)
- block1():
- -> (%data.2, %data.2)
- %res_data : Tensor = aten::where(%cond_data, %data, %a.1_data)
- %res_mask : Tensor = aten::where(%cond_mask, %mask, %a.1_mask)
- %res_dims : Tensor = aten::__or__(%dims, %a.1_dims)
- %34 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims)
- return (%34)
+++ /dev/null
-graph(%a.1_data : Tensor,
- %a.1_mask : Tensor,
- %a.1_dims : Tensor,
- %b_data : Tensor,
- %b_mask : Tensor,
- %b_dims : Tensor):
- %6 : int = prim::Constant[value=1]()
- %7 : float = prim::Constant[value=0.1]()
- %8 : Float() = prim::NumToTensor(%7)
- %other : float = prim::Float(%8)
- %10 : Tensor = aten::gt(%a.1_data, %other)
- %11 : Long() = prim::NumToTensor(%6)
- %alpha : float = prim::Float(%11)
- %data : Tensor = aten::add(%a.1_data, %b_data, %alpha)
- %mask : Tensor = aten::mul(%a.1_mask, %b_mask)
- %dims : Tensor = aten::__or__(%a.1_dims, %b_dims)
- %16 : bool = prim::Constant[value=1]()
- %17 : int = prim::Constant[value=1]()
- %18 : Tensor = aten::type_as(%a.1_mask, %10)
- %data.2 : Tensor = aten::mul(%10, %18)
- %20 : int = aten::dim(%data.2)
- %21 : bool = aten::eq(%20, %17)
- %cond_data : Tensor, %cond_mask : Tensor = prim::If(%21)
- block0():
- %24 : int = aten::dim(%data)
- %25 : int = aten::sub(%24, %17)
- %data.4 : Tensor = prim::Loop(%25, %16, %data.2)
- block0(%27 : int, %28 : Tensor):
- %29 : int = aten::dim(%28)
- %data.3 : Tensor = aten::unsqueeze(%28, %29)
- -> (%16, %data.3)
- %cond_data.1 : Tensor = aten::expand_as(%data.4, %data)
- %cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask)
- -> (%cond_data.1, %cond_mask.1)
- block1():
- -> (%data.2, %data.2)
- %res_data : Tensor = aten::where(%cond_data, %data, %a.1_data)
- %res_mask : Tensor = aten::where(%cond_mask, %mask, %a.1_mask)
- %res_dims : Tensor = aten::__or__(%dims, %a.1_dims)
- %36 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims)
- return (%36)
+++ /dev/null
-graph(%a.1_data : Tensor,
- %a.1_mask : Tensor,
- %a.1_dims : Tensor,
- %b_data : Tensor,
- %b_mask : Tensor,
- %b_dims : Tensor):
- %6 : int = prim::Constant[value=1]()
- %7 : int = prim::Constant[value=9223372036854775807]()
- %8 : Tensor = aten::gt(%a.1_data, %b_data)
- %9 : Tensor = aten::mul(%a.1_mask, %b_mask)
- %10 : Tensor = aten::__or__(%a.1_dims, %b_dims)
- %11 : int = prim::Constant[value=0]()
- %12 : Tensor = aten::mul(%8, %9)
- %13 : Tensor = aten::sum(%12)
- %14 : Tensor = aten::gt(%13, %11)
- %15 : bool = prim::Bool(%14)
- %16 : Tensor, %17 : Tensor, %a : Tensor, %19 : Tensor, %20 : Tensor = prim::Loop(%7, %15, %8, %9, %a.1_data, %a.1_mask, %a.1_dims)
- block0(%loop_num : int, %cond_data.2 : Tensor, %cond_mask.2 : Tensor, %6_data : Tensor, %6_mask : Tensor, %6_dims : Tensor):
- %27 : Long() = prim::NumToTensor(%6)
- %alpha : float = prim::Float(%27)
- %data : Tensor = aten::sub(%6_data, %b_data, %alpha)
- %mask : Tensor = aten::mul(%6_mask, %b_mask)
- %dims : Tensor = aten::__or__(%6_dims, %b_dims)
- %32 : Tensor = aten::gt(%data, %b_data)
- %33 : Tensor = aten::mul(%mask, %b_mask)
- %34 : bool = prim::Constant[value=1]()
- %35 : int = prim::Constant[value=1]()
- %36 : Tensor = aten::type_as(%cond_mask.2, %cond_data.2)
- %data.2 : Tensor = aten::mul(%cond_data.2, %36)
- %38 : int = aten::dim(%data.2)
- %39 : bool = aten::eq(%38, %35)
- %cond_data : Tensor, %cond_mask : Tensor = prim::If(%39)
- block0():
- %42 : int = aten::dim(%data)
- %43 : int = aten::sub(%42, %35)
- %data.4 : Tensor = prim::Loop(%43, %34, %data.2)
- block0(%45 : int, %46 : Tensor):
- %47 : int = aten::dim(%46)
- %data.3 : Tensor = aten::unsqueeze(%46, %47)
- -> (%34, %data.3)
- %cond_data.1 : Tensor = aten::expand_as(%data.4, %data)
- %cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask)
- -> (%cond_data.1, %cond_mask.1)
- block1():
- -> (%data.2, %data.2)
- %res_data : Tensor = aten::where(%cond_data, %data, %6_data)
- %res_mask : Tensor = aten::where(%cond_mask, %mask, %6_mask)
- %res_dims : Tensor = aten::__or__(%dims, %6_dims)
- %54 : int = prim::Constant[value=0]()
- %55 : Tensor = aten::mul(%32, %33)
- %56 : Tensor = aten::sum(%55)
- %57 : Tensor = aten::gt(%56, %54)
- %58 : bool = prim::Bool(%57)
- -> (%58, %32, %33, %res_data, %res_mask, %res_dims)
- %59 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%a, %19, %20)
- return (%59)
+++ /dev/null
-graph(%x : Float(*, *),
- %scale : Float(*),
- %shift : Float(*)):
- %3 : Float(*, *) = prim::FusionGroup_0(%shift, %x, %scale)
- return (%3)
-with prim::FusionGroup_0 = graph(%0 : Float(*),
- %1 : Float(*, *),
- %2 : Float(*)):
- %3 : int = prim::Constant[value=1]()
- %4 : Float(*, *) = aten::mul(%1, %2)
- %5 : Float(*, *) = aten::add(%4, %0, %3)
- return (%5)
+++ /dev/null
-graph(%x : Float(*, *)):
- %1 : Float(*, *) = prim::FusionGroup_0(%x)
- return (%1)
-with prim::FusionGroup_0 = graph(%0 : Float(*, *)):
- %1 : Float(*, *), %2 : Float(*, *), %3 : Float(*, *) = prim::ConstantChunk[chunks=3, dim=1](%0)
- %4 : int = prim::Constant[value=1]()
- %5 : Float(*, *) = aten::mul(%1, %2)
- %6 : Float(*, *) = aten::add(%5, %3, %4)
- return (%6)
+++ /dev/null
-graph(%x : Float(*, *),
- %y : Float(*, *)):
- %2 : Tensor[] = prim::ListConstruct(%x, %y)
- %3 : Tensor[] = aten::broadcast_tensors(%2)
- %4 : Tensor, %5 : Tensor = prim::ListUnpack(%3)
- %6 : Float(*, *) = prim::FusionGroup_0(%5, %4)
- return (%6)
-with prim::FusionGroup_0 = graph(%0 : Tensor,
- %1 : Tensor):
- %2 : Float(*, *), %3 : Float(*, *) = prim::ConstantChunk[chunks=2, dim=1](%1)
- %4 : Float(*, *), %5 : Float(*, *) = prim::ConstantChunk[chunks=2, dim=1](%0)
- %6 : int = prim::Constant[value=1]()
- %7 : Float(*, *) = aten::add(%2, %4, %6)
- %8 : Float(*, *) = aten::add(%3, %5, %6)
- %9 : Float(*, *) = aten::mul(%7, %8)
- return (%9)
+++ /dev/null
-graph(%s : Float(*, *, *),
- %x : Float(*, *, *),
- %y : Float(*, *, *),
- %z : Float(*, *, *)):
- %4 : Float(*, *, *) = prim::FusionGroup_0(%s, %y, %x, %z)
- return (%4)
-with prim::FusionGroup_0 = graph(%0 : Float(*, *, *),
- %1 : Float(*, *, *),
- %2 : Float(*, *, *),
- %3 : Float(*, *, *)):
- %4 : Float(*, *, *), %5 : Float(*, *, *) = prim::ConstantChunk[chunks=2, dim=2](%3)
- %6 : Float(*, *, *), %7 : Float(*, *, *), %8 : Float(*, *, *) = prim::ConstantChunk[chunks=3, dim=1](%2)
- %9 : Float(*, *, *), %10 : Float(*, *, *) = prim::ConstantChunk[chunks=2, dim=0](%1)
- %11 : int = prim::Constant[value=1]()
- %12 : Float(*, *, *) = aten::add(%0, %6, %11)
- %13 : Float(*, *, *) = aten::add(%12, %7, %11)
- %14 : Float(*, *, *) = aten::add(%13, %8, %11)
- %15 : Float(*, *, *) = aten::add(%14, %9, %11)
- %16 : Float(*, *, *) = aten::add(%15, %10, %11)
- %17 : Float(*, *, *) = aten::add(%16, %4, %11)
- %18 : Float(*, *, *) = aten::add(%17, %5, %11)
- return (%18)
+++ /dev/null
-graph(%hx : Float(*, *),
- %cx : Float(*, *)):
- %2 : Float(*, *) = prim::FusionGroup_0(%hx, %cx)
- return (%2)
-with prim::FusionGroup_0 = graph(%0 : Float(*, *),
- %1 : Float(*, *)):
- %2 : int = prim::Constant[value=1]()
- %3 : Float(*, *) = aten::add(%0, %1, %2)
- %4 : Float(*, *) = aten::mul(%0, %1)
- %5 : Float(*, *) = prim::FusedConcat[dim=0](%3, %4)
- return (%5)
+++ /dev/null
-graph(%x : Float(*, *),
- %y : Float(*, *),
- %z : Float(*, *)):
- %3 : int = prim::Constant[value=1]()
- %w : Float(*, *) = prim::FusionGroup_0(%x, %y)
- %5 : Float(*, *) = aten::add(%w, %z, %3)
- return (%5)
-with prim::FusionGroup_0 = graph(%0 : Float(*, *),
- %1 : Float(*, *)):
- %2 : int = prim::Constant[value=1]()
- %x1 : Float(*, *) = aten::add(%0, %1, %2)
- %y1 : Float(*, *) = aten::sub(%0, %1, %2)
- %w : Float(*, *) = prim::FusedConcat[dim=0](%x1, %y1)
- return (%w)
+++ /dev/null
-graph(%input : Float(*, *),
- %input0 : Float(*, *),
- %cx : Float(*, *),
- %weight : Float(*, *),
- %weight0 : Float(*, *),
- %bias : Float(*),
- %bias0 : Float(*)):
- %7 : Float(*, *) = aten::t(%weight)
- %8 : Float(*, *) = aten::mm(%input, %7)
- %9 : Float(*, *) = aten::t(%weight0)
- %10 : Float(*, *) = aten::mm(%input0, %9)
- %11 : Tensor[] = prim::ListConstruct(%bias, %8, %bias0, %10)
- %12 : Tensor[] = aten::broadcast_tensors(%11)
- %13 : Tensor, %14 : Tensor, %15 : Tensor, %16 : Tensor = prim::ListUnpack(%12)
- %17 : Float(*, *) = prim::FusionGroup_0(%cx, %16, %15, %14, %13)
- return (%17)
-with prim::FusionGroup_0 = graph(%0 : Float(*, *),
- %1 : Tensor,
- %2 : Tensor,
- %3 : Tensor,
- %4 : Tensor):
- %5 : Float(*, *), %6 : Float(*, *), %7 : Float(*, *), %8 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
- %9 : Float(*, *), %10 : Float(*, *), %11 : Float(*, *), %12 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
- %13 : Float(*, *), %14 : Float(*, *), %15 : Float(*, *), %16 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
- %17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
- %21 : int = prim::Constant[value=1]()
- %22 : Float(*, *) = aten::add(%13, %17, %21)
- %23 : Float(*, *) = aten::add(%14, %18, %21)
- %24 : Float(*, *) = aten::add(%15, %19, %21)
- %25 : Float(*, *) = aten::add(%16, %20, %21)
- %26 : Float(*, *) = aten::add(%5, %9, %21)
- %27 : Float(*, *) = aten::add(%6, %10, %21)
- %28 : Float(*, *) = aten::add(%7, %11, %21)
- %29 : Float(*, *) = aten::add(%8, %12, %21)
- %30 : Float(*, *) = aten::add(%26, %22, %21)
- %31 : Float(*, *) = aten::add(%27, %23, %21)
- %32 : Float(*, *) = aten::add(%28, %24, %21)
- %33 : Float(*, *) = aten::add(%29, %25, %21)
- %ingate0 : Float(*, *) = aten::sigmoid(%30)
- %forgetgate0 : Float(*, *) = aten::sigmoid(%31)
- %cellgate0 : Float(*, *) = aten::tanh(%32)
- %outgate0 : Float(*, *) = aten::sigmoid(%33)
- %38 : Float(*, *) = aten::mul(%forgetgate0, %0)
- %39 : Float(*, *) = aten::mul(%ingate0, %cellgate0)
- %cy : Float(*, *) = aten::add(%38, %39, %21)
- %41 : Float(*, *) = aten::tanh(%cy)
- %hy : Float(*, *) = aten::mul(%outgate0, %41)
- %43 : Float(*, *) = prim::FusedConcat[dim=0](%hy, %cy)
- return (%43)
+++ /dev/null
-graph(%0 : Float(*, *),
- %1 : Float(*, *),
- %2 : UndefinedTensor,
- %3 : UndefinedTensor,
- %4 : UndefinedTensor,
- %5 : UndefinedTensor,
- %6 : UndefinedTensor,
- %7 : UndefinedTensor,
- %8 : UndefinedTensor,
- %9 : Float(*, *),
- %10 : Float(*, *),
- %11 : Float(*, *),
- %12 : Float(*, *),
- %13 : Float(*, *),
- %14 : int[],
- %15 : int[],
- %16 : int[],
- %17 : int[],
- %18 : int[],
- %19 : int[],
- %ingate : Float(*, *),
- %forgetgate : Float(*, *),
- %cellgate : Float(*, *),
- %outgate : Float(*, *),
- %24 : int[],
- %25 : int[],
- %26 : Float(*, *)):
- %27 : int = prim::Constant[value=1]()
- %28 : int[] = aten::size(%outgate)
- %29 : int[] = aten::size(%26)
- %30 : int[] = aten::size(%ingate)
- %31 : int[] = aten::size(%cellgate)
- %32 : int[] = aten::size(%forgetgate)
- %33 : int[] = aten::size(%9)
- %34 : Tensor = prim::FusionGroup_0(%outgate, %0, %26, %28)
- %grad_other.5 : Tensor, %36 : Tensor, %37 : Tensor, %38 : Tensor = prim::FusionGroup_1(%forgetgate, %9, %ingate, %cellgate, %1, %26, %0, %outgate, %33, %32, %24, %31, %30, %25, %29)
- %39 : Tensor[] = prim::ListConstruct(%38, %36, %37, %34)
- %40 : Tensor = aten::cat(%39, %27)
- %41 : Tensor = aten::_grad_sum_to_size(%40, %19)
- %42 : Tensor = aten::_grad_sum_to_size(%40, %17)
- %43 : Tensor = aten::_grad_sum_to_size(%40, %14)
- %44 : Tensor = aten::_grad_sum_to_size(%40, %15)
- %45 : Float(*, *) = aten::t(%13)
- %grad_self.7 : Float(*, *) = aten::mm(%44, %45)
- %47 : Float(*, *) = aten::t(%10)
- %grad_mat2.1 : Float(*, *) = aten::mm(%47, %44)
- %grad_self.9 : Float(*, *) = aten::t(%grad_mat2.1)
- %50 : Float(*, *) = aten::t(%12)
- %grad_self.11 : Float(*, *) = aten::mm(%43, %50)
- %52 : Float(*, *) = aten::t(%11)
- %grad_mat2.3 : Float(*, *) = aten::mm(%52, %43)
- %grad_self.13 : Float(*, *) = aten::t(%grad_mat2.3)
- return (%grad_other.5, %41, %42, %grad_self.7, %grad_self.9, %grad_self.11, %grad_self.13)
-with prim::FusionGroup_0 = graph(%0 : Float(*, *),
- %1 : Float(*, *),
- %2 : Float(*, *),
- %3 : int[]):
- %4 : int = prim::Constant[value=1]()
- %5 : Float(*, *) = aten::mul(%1, %2)
- %grad_self.1 : Tensor = aten::_grad_sum_to_size(%5, %3)
- %7 : Float(*, *) = aten::neg(%0)
- %8 : Float(*, *) = aten::add(%7, %4, %4)
- %9 : Float(*, *) = aten::mul(%8, %0)
- %10 : Tensor = aten::mul(%9, %grad_self.1)
- return (%10)
-with prim::FusionGroup_1 = graph(%0 : Float(*, *),
- %1 : Float(*, *),
- %2 : Float(*, *),
- %3 : Float(*, *),
- %4 : Float(*, *),
- %5 : Float(*, *),
- %6 : Float(*, *),
- %7 : Float(*, *),
- %8 : int[],
- %9 : int[],
- %10 : int[],
- %11 : int[],
- %12 : int[],
- %13 : int[],
- %14 : int[]):
- %15 : int = prim::Constant[value=1]()
- %16 : Float(*, *) = aten::neg(%0)
- %17 : Float(*, *) = aten::add(%16, %15, %15)
- %18 : Float(*, *) = aten::mul(%17, %0)
- %19 : Float(*, *) = aten::mul(%3, %3)
- %20 : Float(*, *) = aten::neg(%19)
- %21 : Float(*, *) = aten::add(%20, %15, %15)
- %22 : Float(*, *) = aten::mul(%6, %7)
- %grad_other.1 : Tensor = aten::_grad_sum_to_size(%22, %14)
- %24 : Float(*, *) = aten::mul(%5, %5)
- %25 : Float(*, *) = aten::neg(%24)
- %26 : Float(*, *) = aten::add(%25, %15, %15)
- %27 : Tensor = aten::mul(%grad_other.1, %26)
- %28 : Tensor = aten::add(%4, %27, %15)
- %29 : Tensor = aten::_grad_sum_to_size(%28, %13)
- %30 : Tensor = aten::mul(%29, %3)
- %grad_self.3 : Tensor = aten::_grad_sum_to_size(%30, %12)
- %32 : Float(*, *) = aten::neg(%2)
- %33 : Float(*, *) = aten::add(%32, %15, %15)
- %34 : Float(*, *) = aten::mul(%33, %2)
- %35 : Tensor = aten::mul(%34, %grad_self.3)
- %36 : Tensor = aten::mul(%29, %2)
- %grad_other.3 : Tensor = aten::_grad_sum_to_size(%36, %11)
- %38 : Tensor = aten::mul(%grad_other.3, %21)
- %39 : Tensor = aten::_grad_sum_to_size(%28, %10)
- %40 : Tensor = aten::mul(%39, %1)
- %grad_self.5 : Tensor = aten::_grad_sum_to_size(%40, %9)
- %42 : Tensor = aten::mul(%18, %grad_self.5)
- %43 : Tensor = aten::mul(%39, %0)
- %grad_other.5 : Tensor = aten::_grad_sum_to_size(%43, %8)
- return (%grad_other.5, %42, %38, %35)
+++ /dev/null
-graph(%x : Float(*, *),
- %hx : Float(*, *),
- %cx : Float(*, *),
- %w_ih : Float(*, *),
- %w_hh : Float(*, *),
- %b_ih : Float(*),
- %b_hh : Float(*)):
- %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih)
- %9 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
- return (%9)
-with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *),
- %1 : Float(*),
- %2 : Float(*),
- %3 : Float(*, *),
- %4 : Float(*, *),
- %5 : Float(*, *),
- %6 : Float(*, *)):
- %7 : Float(*, *) = aten::t(%6)
- %8 : Float(*, *) = aten::mm(%5, %7)
- %9 : Float(*, *) = aten::t(%4)
- %10 : Float(*, *) = aten::mm(%3, %9)
- %11 : int[] = aten::size(%8)
- %12 : int[] = aten::size(%10)
- %13 : int[] = aten::size(%2)
- %14 : int[] = aten::size(%1)
- %15 : Tensor[] = prim::ListConstruct(%1, %2, %8, %10)
- %16 : Tensor[] = aten::broadcast_tensors(%15)
- %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor = prim::ListUnpack(%16)
- %21 : int[] = prim::BroadcastSizes(%11, %12)
- %22 : int[] = prim::BroadcastSizes(%21, %13)
- %hy : Float(*, *), %24 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
- %30 : int[] = aten::size(%0)
- %31 : int[] = aten::size(%cellgate.1)
- %32 : int[] = aten::size(%forgetgate.1)
- %33 : int[] = aten::size(%ingate.1)
- %34 : int[] = prim::BroadcastSizes(%32, %30)
- %35 : int[] = prim::BroadcastSizes(%33, %31)
- return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %35, %24)
-with prim::FusionGroup_0 = graph(%0 : Float(*, *),
- %1 : Tensor,
- %2 : Tensor,
- %3 : Tensor,
- %4 : Tensor):
- %5 : Float(*, *), %6 : Float(*, *), %7 : Float(*, *), %8 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
- %9 : Float(*, *), %10 : Float(*, *), %11 : Float(*, *), %12 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
- %13 : Float(*, *), %14 : Float(*, *), %15 : Float(*, *), %16 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
- %17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
- %21 : int = prim::Constant[value=1]()
- %22 : Float(*, *) = aten::add(%13, %17, %21)
- %23 : Float(*, *) = aten::add(%14, %18, %21)
- %24 : Float(*, *) = aten::add(%15, %19, %21)
- %25 : Float(*, *) = aten::add(%16, %20, %21)
- %26 : Float(*, *) = aten::add(%22, %9, %21)
- %27 : Float(*, *) = aten::add(%23, %10, %21)
- %28 : Float(*, *) = aten::add(%24, %11, %21)
- %29 : Float(*, *) = aten::add(%25, %12, %21)
- %30 : Float(*, *) = aten::add(%26, %5, %21)
- %31 : Float(*, *) = aten::add(%27, %6, %21)
- %32 : Float(*, *) = aten::add(%28, %7, %21)
- %33 : Float(*, *) = aten::add(%29, %8, %21)
- %ingate.1 : Float(*, *) = aten::sigmoid(%30)
- %forgetgate.1 : Float(*, *) = aten::sigmoid(%31)
- %cellgate.1 : Float(*, *) = aten::tanh(%32)
- %outgate.1 : Float(*, *) = aten::sigmoid(%33)
- %38 : Float(*, *) = aten::mul(%forgetgate.1, %0)
- %39 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
- %cy : Float(*, *) = aten::add(%38, %39, %21)
- %41 : Float(*, *) = aten::tanh(%cy)
- %hy : Float(*, *) = aten::mul(%outgate.1, %41)
- return (%hy, %41, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1)
+++ /dev/null
-graph(%0 : Float(3, 10),
- %1 : Float(3, 20),
- %2 : Float(3, 20),
- %3 : Float(80, 10),
- %4 : Float(80, 20),
- %5 : Float(80),
- %6 : Float(80)):
- %7 : Float(10!, 80!) = aten::t(%3)
- %8 : Float(3!, 80) = aten::expand[size=[3, 80], implicit=0](%5)
- %9 : Float(3, 80) = aten::addmm[alpha={1}, beta={1}](%8, %0, %7)
- %10 : Float(20!, 80!) = aten::t(%4)
- %11 : Float(3!, 80) = aten::expand[size=[3, 80], implicit=0](%6)
- %12 : Float(3, 80) = aten::addmm[alpha={1}, beta={1}](%11, %1, %10)
- %13 : Float(3!, 20), %14 : Float(3!, 20), %15 : Float(3!, 20), %16 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%9)
- %17 : Float(3!, 20), %18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%12)
- %21 : Float(3, 20), %22 : Float(3, 20) = prim::FusionGroup_0[device=-1](%2, %16, %20, %15, %19, %14, %18, %13, %17)
- return (%21, %22)
-with prim::FusionGroup_0 = graph(%12 : Float(3, 20),
- %22 : Float(3!, 20),
- %23 : Float(3!, 20),
- %25 : Float(3!, 20),
- %26 : Float(3!, 20),
- %28 : Float(3!, 20),
- %29 : Float(3!, 20),
- %31 : Float(3!, 20),
- %32 : Float(3!, 20)):
- %33 : Float(3, 20) = aten::add[alpha={1}](%31, %32)
- %30 : Float(3, 20) = aten::add[alpha={1}](%28, %29)
- %27 : Float(3, 20) = aten::add[alpha={1}](%25, %26)
- %24 : Float(3, 20) = aten::add[alpha={1}](%22, %23)
- %21 : Float(3, 20) = aten::sigmoid(%33)
- %19 : Float(3, 20) = aten::sigmoid(%30)
- %17 : Float(3, 20) = aten::tanh(%27)
- %15 : Float(3, 20) = aten::sigmoid(%24)
- %13 : Float(3, 20) = aten::mul(%19, %12)
- %10 : Float(3, 20) = aten::mul(%21, %17)
- %7 : Float(3, 20) = aten::add[alpha={1}](%13, %10)
- %4 : Float(3, 20) = aten::tanh(%7)
- %2 : Float(3, 20) = aten::mul(%15, %4)
- return (%2, %7)
+++ /dev/null
-graph(%input : Float(*, *),
- %input0 : Float(*, *),
- %cx : Float(*, *),
- %weight : Float(*, *),
- %weight0 : Float(*, *),
- %bias : Float(*),
- %bias0 : Float(*)):
- %7 : Float(*, *) = aten::t(%weight)
- %8 : Float(*, *) = aten::mm(%input, %7)
- %9 : Float(*, *) = aten::t(%weight0)
- %10 : Float(*, *) = aten::mm(%input0, %9)
- %11 : Tensor[] = prim::ListConstruct(%bias, %8, %bias0, %10)
- %12 : Tensor[] = aten::broadcast_tensors(%11)
- %13 : Tensor, %14 : Tensor, %15 : Tensor, %16 : Tensor = prim::ListUnpack(%12)
- %17 : Float(*, *), %cy : Float(*, *) = prim::FusionGroup_0(%cx, %16, %15, %14, %13)
- %19 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%17, %cy)
- return (%19)
-with prim::FusionGroup_0 = graph(%0 : Float(*, *),
- %1 : Tensor,
- %2 : Tensor,
- %3 : Tensor,
- %4 : Tensor):
- %5 : Float(*, *), %6 : Float(*, *), %7 : Float(*, *), %8 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
- %9 : Float(*, *), %10 : Float(*, *), %11 : Float(*, *), %12 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
- %13 : Float(*, *), %14 : Float(*, *), %15 : Float(*, *), %16 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
- %17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
- %21 : int = prim::Constant[value=1]()
- %22 : Float(*, *) = aten::add(%13, %17, %21)
- %23 : Float(*, *) = aten::add(%14, %18, %21)
- %24 : Float(*, *) = aten::add(%15, %19, %21)
- %25 : Float(*, *) = aten::add(%16, %20, %21)
- %26 : Float(*, *) = aten::add(%5, %9, %21)
- %27 : Float(*, *) = aten::add(%6, %10, %21)
- %28 : Float(*, *) = aten::add(%7, %11, %21)
- %29 : Float(*, *) = aten::add(%8, %12, %21)
- %30 : Float(*, *) = aten::add(%26, %22, %21)
- %31 : Float(*, *) = aten::add(%27, %23, %21)
- %32 : Float(*, *) = aten::add(%28, %24, %21)
- %33 : Float(*, *) = aten::add(%29, %25, %21)
- %ingate0 : Float(*, *) = aten::sigmoid(%30)
- %forgetgate0 : Float(*, *) = aten::sigmoid(%31)
- %cellgate0 : Float(*, *) = aten::tanh(%32)
- %outgate0 : Float(*, *) = aten::sigmoid(%33)
- %38 : Float(*, *) = aten::mul(%forgetgate0, %0)
- %39 : Float(*, *) = aten::mul(%ingate0, %cellgate0)
- %cy : Float(*, *) = aten::add(%38, %39, %21)
- %41 : Float(*, *) = aten::tanh(%cy)
- %42 : Float(*, *) = aten::mul(%outgate0, %41)
- return (%42, %cy)
+++ /dev/null
-graph(%0 : Float(*, *),
- %1 : Float(*, *),
- %2 : UndefinedTensor,
- %3 : UndefinedTensor,
- %4 : UndefinedTensor,
- %5 : UndefinedTensor,
- %6 : UndefinedTensor,
- %7 : UndefinedTensor,
- %8 : UndefinedTensor,
- %9 : UndefinedTensor,
- %10 : Float(*, *),
- %11 : Float(*),
- %12 : Float(*),
- %13 : Float(*),
- %14 : Float(*, *),
- %15 : Float(*, *),
- %Wx : Float(*, *),
- %Uz : Float(*, *),
- %18 : Float(*, *),
- %19 : int[],
- %20 : int[],
- %21 : int[],
- %22 : int[],
- %23 : int[],
- %24 : int[],
- %ingate : Float(*, *),
- %forgetgate : Float(*, *),
- %cellgate : Float(*, *),
- %outgate : Float(*, *),
- %29 : int[],
- %30 : int[],
- %31 : Float(*, *)):
- %32 : int = prim::Constant[value=1]()
- %33 : int[] = aten::size(%outgate)
- %34 : int[] = aten::size(%31)
- %35 : int[] = aten::size(%ingate)
- %36 : int[] = aten::size(%cellgate)
- %37 : int[] = aten::size(%forgetgate)
- %38 : Tensor = prim::FusionGroup_0(%outgate, %0, %31, %33)
- %39 : Tensor, %40 : Tensor, %41 : Tensor = prim::FusionGroup_1(%10, %ingate, %cellgate, %1, %31, %0, %outgate, %forgetgate, %37, %29, %36, %35, %30, %34)
- %42 : Tensor[] = prim::ListConstruct(%41, %39, %40, %38)
- %43 : Tensor = aten::cat(%42, %32)
- %44 : Tensor = aten::_grad_sum_to_size(%43, %24)
- %45 : Tensor = aten::_grad_sum_to_size(%43, %22)
- %46 : int[] = aten::size(%11)
- %grad_self.7 : Tensor = prim::FusionGroup_2(%45, %Uz, %46)
- %48 : int[] = aten::size(%Uz)
- %49 : Tensor = aten::_grad_sum_to_size(%43, %19)
- %50 : Tensor = aten::_grad_sum_to_size(%43, %20)
- %51 : int[] = aten::size(%12)
- %grad_self.9 : Tensor = prim::FusionGroup_3(%50, %Wx, %51)
- %53 : int[] = aten::size(%Wx)
- %54 : int[] = aten::size(%18)
- %55 : Tensor = prim::FusionGroup_4(%49, %18, %45, %11, %48)
- %56 : int[] = aten::size(%13)
- %grad_self.13 : Tensor, %58 : Tensor = prim::FusionGroup_5(%Wx, %13, %49, %Uz, %50, %12, %56, %53, %54)
- %59 : Float(*, *) = aten::t(%14)
- %grad_mat2.1 : Float(*, *) = aten::mm(%59, %55)
- %grad_self.17 : Float(*, *) = aten::t(%grad_mat2.1)
- %62 : Float(*, *) = aten::t(%15)
- %grad_mat2.3 : Float(*, *) = aten::mm(%62, %58)
- %grad_self.21 : Float(*, *) = aten::t(%grad_mat2.3)
- return (%44, %grad_self.7, %grad_self.9, %grad_self.13, %grad_self.17, %grad_self.21)
-with prim::FusionGroup_0 = graph(%0 : Float(*, *),
- %1 : Float(*, *),
- %2 : Float(*, *),
- %3 : int[]):
- %4 : int = prim::Constant[value=1]()
- %5 : Float(*, *) = aten::mul(%1, %2)
- %grad_self.1 : Tensor = aten::_grad_sum_to_size(%5, %3)
- %7 : Float(*, *) = aten::neg(%0)
- %8 : Float(*, *) = aten::add(%7, %4, %4)
- %9 : Float(*, *) = aten::mul(%8, %0)
- %10 : Tensor = aten::mul(%9, %grad_self.1)
- return (%10)
-with prim::FusionGroup_1 = graph(%0 : Float(*, *),
- %1 : Float(*, *),
- %2 : Float(*, *),
- %3 : Float(*, *),
- %4 : Float(*, *),
- %5 : Float(*, *),
- %6 : Float(*, *),
- %7 : Float(*, *),
- %8 : int[],
- %9 : int[],
- %10 : int[],
- %11 : int[],
- %12 : int[],
- %13 : int[]):
- %14 : int = prim::Constant[value=1]()
- %15 : Float(*, *) = aten::neg(%7)
- %16 : Float(*, *) = aten::add(%15, %14, %14)
- %17 : Float(*, *) = aten::mul(%16, %7)
- %18 : Float(*, *) = aten::mul(%2, %2)
- %19 : Float(*, *) = aten::neg(%18)
- %20 : Float(*, *) = aten::add(%19, %14, %14)
- %21 : Float(*, *) = aten::mul(%5, %6)
- %grad_other.1 : Tensor = aten::_grad_sum_to_size(%21, %13)
- %23 : Float(*, *) = aten::mul(%4, %4)
- %24 : Float(*, *) = aten::neg(%23)
- %25 : Float(*, *) = aten::add(%24, %14, %14)
- %26 : Tensor = aten::mul(%grad_other.1, %25)
- %27 : Tensor = aten::add(%3, %26, %14)
- %28 : Tensor = aten::_grad_sum_to_size(%27, %12)
- %29 : Tensor = aten::mul(%28, %2)
- %grad_self.3 : Tensor = aten::_grad_sum_to_size(%29, %11)
- %31 : Float(*, *) = aten::neg(%1)
- %32 : Float(*, *) = aten::add(%31, %14, %14)
- %33 : Float(*, *) = aten::mul(%32, %1)
- %34 : Tensor = aten::mul(%33, %grad_self.3)
- %35 : Tensor = aten::mul(%28, %1)
- %grad_other.3 : Tensor = aten::_grad_sum_to_size(%35, %10)
- %37 : Tensor = aten::mul(%grad_other.3, %20)
- %38 : Tensor = aten::_grad_sum_to_size(%27, %9)
- %39 : Tensor = aten::mul(%38, %0)
- %grad_self.5 : Tensor = aten::_grad_sum_to_size(%39, %8)
- %41 : Tensor = aten::mul(%17, %grad_self.5)
- return (%41, %37, %34)
-with prim::FusionGroup_2 = graph(%0 : Tensor,
- %1 : Float(*, *),
- %2 : int[]):
- %3 : Tensor = aten::mul(%0, %1)
- %grad_self.7 : Tensor = aten::_grad_sum_to_size(%3, %2)
- return (%grad_self.7)
-with prim::FusionGroup_3 = graph(%0 : Tensor,
- %1 : Float(*, *),
- %2 : int[]):
- %3 : Tensor = aten::mul(%0, %1)
- %grad_self.9 : Tensor = aten::_grad_sum_to_size(%3, %2)
- return (%grad_self.9)
-with prim::FusionGroup_4 = graph(%0 : Tensor,
- %1 : Float(*, *),
- %2 : Tensor,
- %3 : Float(*),
- %4 : int[]):
- %5 : int = prim::Constant[value=1]()
- %6 : Tensor = aten::mul(%2, %3)
- %grad_other.7 : Tensor = aten::_grad_sum_to_size(%6, %4)
- %8 : Tensor = aten::mul(%0, %1)
- %grad_other.11 : Tensor = aten::_grad_sum_to_size(%8, %4)
- %10 : Tensor = aten::add(%grad_other.7, %grad_other.11, %5)
- return (%10)
-with prim::FusionGroup_5 = graph(%0 : Float(*, *),
- %1 : Float(*),
- %2 : Tensor,
- %3 : Float(*, *),
- %4 : Tensor,
- %5 : Float(*),
- %6 : int[],
- %7 : int[],
- %8 : int[]):
- %9 : int = prim::Constant[value=1]()
- %10 : Tensor = aten::mul(%4, %5)
- %grad_other.9 : Tensor = aten::_grad_sum_to_size(%10, %7)
- %12 : Tensor = aten::mul(%2, %3)
- %grad_self.11 : Tensor = aten::_grad_sum_to_size(%12, %8)
- %14 : Tensor = aten::mul(%grad_self.11, %1)
- %grad_other.13 : Tensor = aten::_grad_sum_to_size(%14, %7)
- %16 : Tensor = aten::add(%grad_other.9, %grad_other.13, %9)
- %17 : Tensor = aten::mul(%grad_self.11, %0)
- %grad_self.13 : Tensor = aten::_grad_sum_to_size(%17, %6)
- return (%grad_self.13, %16)
+++ /dev/null
-graph(%x : Float(*, *),
- %hx : Float(*, *),
- %cx : Float(*, *),
- %w_ih : Float(*, *),
- %w_hh : Float(*, *),
- %alpha : Float(*),
- %beta_i : Float(*),
- %beta_h : Float(*),
- %bias : Float(*)):
- %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %bias, %beta_h, %beta_i, %alpha, %hx, %w_hh, %x, %w_ih)
- %11 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
- return (%11)
-with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *),
- %1 : Float(*),
- %2 : Float(*),
- %3 : Float(*),
- %4 : Float(*),
- %5 : Float(*, *),
- %6 : Float(*, *),
- %7 : Float(*, *),
- %8 : Float(*, *)):
- %9 : Float(*, *) = aten::t(%8)
- %Wx.1 : Float(*, *) = aten::mm(%7, %9)
- %11 : Float(*, *) = aten::t(%6)
- %Uz.1 : Float(*, *) = aten::mm(%5, %11)
- %13 : Float(*, *) = aten::mul(%4, %Wx.1)
- %14 : int[] = aten::size(%1)
- %15 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %13, %3, %Wx.1)
- %16 : Tensor[] = aten::broadcast_tensors(%15)
- %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor, %21 : Tensor, %22 : Tensor = prim::ListUnpack(%16)
- %23 : int[] = aten::size(%3)
- %24 : int[] = aten::size(%Wx.1)
- %25 : int[] = prim::BroadcastSizes(%23, %24)
- %26 : int[] = aten::size(%13)
- %27 : int[] = aten::size(%Uz.1)
- %28 : int[] = prim::BroadcastSizes(%26, %27)
- %29 : int[] = aten::size(%2)
- %30 : int[] = prim::BroadcastSizes(%29, %27)
- %31 : int[] = prim::BroadcastSizes(%28, %25)
- %32 : int[] = prim::BroadcastSizes(%31, %30)
- %hy : Float(*, *), %34 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %22, %21, %20, %19, %18, %17)
- %40 : int[] = aten::size(%0)
- %41 : int[] = aten::size(%cellgate.1)
- %42 : int[] = aten::size(%forgetgate.1)
- %43 : int[] = aten::size(%ingate.1)
- %44 : int[] = prim::BroadcastSizes(%42, %40)
- %45 : int[] = prim::BroadcastSizes(%43, %41)
- return (%hy, %cy, %Wx.1, %Uz.1, %13, %28, %25, %31, %30, %32, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %44, %45, %34)
-with prim::FusionGroup_0 = graph(%0 : Float(*, *),
- %1 : Tensor,
- %2 : Tensor,
- %3 : Tensor,
- %4 : Tensor,
- %5 : Tensor,
- %6 : Tensor):
- %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%6)
- %11 : Float(*, *), %12 : Float(*, *), %13 : Float(*, *), %14 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%5)
- %15 : Float(*, *), %16 : Float(*, *), %17 : Float(*, *), %18 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
- %19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
- %23 : Float(*, *), %24 : Float(*, *), %25 : Float(*, *), %26 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
- %27 : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %30 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
- %31 : int = prim::Constant[value=1]()
- %32 : Float(*, *) = aten::mul(%23, %27)
- %33 : Float(*, *) = aten::mul(%24, %28)
- %34 : Float(*, *) = aten::mul(%25, %29)
- %35 : Float(*, *) = aten::mul(%26, %30)
- %36 : Float(*, *) = aten::mul(%19, %15)
- %37 : Float(*, *) = aten::mul(%20, %16)
- %38 : Float(*, *) = aten::mul(%21, %17)
- %39 : Float(*, *) = aten::mul(%22, %18)
- %40 : Float(*, *) = aten::mul(%11, %15)
- %41 : Float(*, *) = aten::mul(%12, %16)
- %42 : Float(*, *) = aten::mul(%13, %17)
- %43 : Float(*, *) = aten::mul(%14, %18)
- %44 : Float(*, *) = aten::add(%36, %32, %31)
- %45 : Float(*, *) = aten::add(%37, %33, %31)
- %46 : Float(*, *) = aten::add(%38, %34, %31)
- %47 : Float(*, *) = aten::add(%39, %35, %31)
- %48 : Float(*, *) = aten::add(%44, %40, %31)
- %49 : Float(*, *) = aten::add(%45, %41, %31)
- %50 : Float(*, *) = aten::add(%46, %42, %31)
- %51 : Float(*, *) = aten::add(%47, %43, %31)
- %52 : Float(*, *) = aten::add(%48, %7, %31)
- %53 : Float(*, *) = aten::add(%49, %8, %31)
- %54 : Float(*, *) = aten::add(%50, %9, %31)
- %55 : Float(*, *) = aten::add(%51, %10, %31)
- %ingate.1 : Float(*, *) = aten::sigmoid(%52)
- %forgetgate.1 : Float(*, *) = aten::sigmoid(%53)
- %cellgate.1 : Float(*, *) = aten::tanh(%54)
- %outgate.1 : Float(*, *) = aten::sigmoid(%55)
- %60 : Float(*, *) = aten::mul(%forgetgate.1, %0)
- %61 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
- %cy : Float(*, *) = aten::add(%60, %61, %31)
- %63 : Float(*, *) = aten::tanh(%cy)
- %hy : Float(*, *) = aten::mul(%outgate.1, %63)
- return (%hy, %63, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1)
+++ /dev/null
-graph(%x : Float(),
- %y : Float()):
- %2 : Float() = prim::FusionGroup_0(%y, %x)
- return (%2)
-with prim::FusionGroup_0 = graph(%0 : Float(),
- %1 : Float()):
- %2 : int = prim::Constant[value=1]()
- %3 : int = prim::Constant[value=2]()
- %4 : Float() = aten::mul(%1, %3)
- %5 : Float() = aten::add(%4, %0, %2)
- return (%5)
+++ /dev/null
-graph(%mat : Tensor,
- %mat1 : Tensor,
- %mat2 : Tensor,
- %alpha : Tensor,
- %beta : Tensor):
- %5 : int = prim::Constant[value=1]()
- %6 : float = prim::Constant[value=4.2]()
- %7 : float = prim::Constant[value=2]()
- %8 : Tensor = aten::mm(%mat1, %mat2)
- %9 : int = prim::Constant[value=1]()
- %10 : Tensor = aten::add(%mat, %8, %9)
- %11 : Tensor = aten::mm(%mat1, %mat2)
- %12 : int = prim::Constant[value=1]()
- %13 : Tensor = aten::add(%mat, %11, %12)
- %c : Tensor = aten::addmm(%mat, %mat1, %mat2, %7, %6)
- %15 : int = prim::Int(%alpha)
- %16 : int = prim::Int(%beta)
- %d : Tensor = aten::addmm(%mat, %mat1, %mat2, %16, %15)
- %18 : Tensor = aten::add(%10, %13, %5)
- %19 : Tensor = aten::add(%18, %c, %5)
- %20 : Tensor = aten::add(%19, %d, %5)
- return (%20)
+++ /dev/null
-graph(%x : Double(1),
- %y : Double(1)):
- return (%x)
+++ /dev/null
-graph(%0 : Double(1),
- %1 : Double(1)):
- %2 : Double(1) = aten::type_as(%0, %1)
- return (%2)
+++ /dev/null
-graph(%x : Double(1),
- %y : Double(1)):
- return (%x)
+++ /dev/null
-graph(%x : Double(*, *),
- %y : Double(*, *),
- %c : Double(*, *)):
- %3 : int = prim::Constant[value=1](), scope: AddmmWrapper
- %4 : Double(*, *) = aten::mm(%x, %y), scope: AddmmWrapper
- %5 : Double(*, *) = aten::add(%4, %c, %3), scope: AddmmWrapper
- return (%5)
+++ /dev/null
-ModelProto {
- producer_name: "pytorch"
- domain: ""
- doc_string: ""
- graph:
- GraphProto {
- name: "torch-jit-export"
- inputs: [{name: "0", type:Tensor dims: 3 4},{name: "1", type:Tensor dims: 4 5},{name: "2", type:Tensor dims: 3 5}]
- outputs: [{name: "3", type:Tensor dims: 3 5}]
- initializers: []
- nodes: [
- Node {type: "Gemm", inputs: [0,1,2], outputs: [3], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1}]}
- ]
- }
- opset_import: [OperatorSetIdProto { domain: }],
-}
+++ /dev/null
-graph(%x : Double(3, 4)):
- %1 : Double(3, 4) = aten::neg(%x)
- %2 : Long() = prim::Constant[value={1}]()
- %3 : int = prim::Constant[value=1]()
- %4 : Double(3, 4) = aten::add(%1, %2, %3)
- return (%4)
+++ /dev/null
-graph(%x : Tensor):
- %1 : int = prim::Constant[value=1]()
- %2 : Tensor[] = prim::ListConstruct(%x, %x)
- %3 : Tensor = aten::cat(%2, %1)
- return (%3)
-graph(%x : Tensor):
- %1 : int = prim::Constant[value=1]()
- %2 : Tensor[] = prim::ListConstruct()
- %3 : Tensor = aten::cat(%2, %1)
- return (%3)
-graph(%x : Tensor):
- %1 : int = prim::Constant[value=1]()
- %2 : Tensor[] = prim::ListConstruct(%x)
- %3 : Tensor = aten::cat(%2, %1)
- return (%3)
+++ /dev/null
-graph(%x : Tensor):
- %1 : bool = prim::Constant[value=0]()
- %2 : int = prim::Constant[value=4]()
- %3 : int[] = prim::ListConstruct(%2)
- %4 : Tensor = aten::sum(%x, %3, %1)
- return (%4)
+++ /dev/null
-graph(%x : Double(*, *, *, *, *)):
- %1 : bool = prim::Constant[value=0]()
- %2 : int = prim::Constant[value=4]()
- %3 : int[] = prim::ListConstruct(%2)
- %4 : Tensor = aten::sum(%x, %3, %1)
- return (%4)
+++ /dev/null
-graph(%x : Double(2, 3, 4)):
- %1 : Double(2, 3, 4) = aten::contiguous(%x)
- return (%1)
return x.type_as(y)
tf = torch.jit.trace(f, (a, b))
+ FileCheck().check("type_as").run(str(tf.graph))
self.run_pass('peephole', tf.graph)
- self.assertExpectedGraph(tf.graph)
+ FileCheck().check_not("type_as").run(str(tf.graph))
tf2 = torch.jit.trace(f, (a, c))
s = str(tf2.graph)
self.run_pass('peephole', tf2.graph)
self.assertEqual(s, str(trace.graph))
trace = torch.jit.trace(f, (b, c))
self.run_pass('peephole', trace.graph)
- self.assertExpectedGraph(trace.graph, subname="same_device")
+ self.assertTrue(len(list(trace.graph.nodes())) == 0)
def test_index(self):
x = torch.tensor([0.4], requires_grad=True)
self.assertEqual(t_node.attributeNames(), ["a"])
g2.appendNode(t_node)
self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t("a")))
- self.assertExpected(str(g2))
+ for node in g.nodes():
+ self.assertTrue(g2.findNode(node.kind()) is not None)
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
fn(y)
def test_decompose_addmm(self):
- @torch.jit.script
- def addmm(mat, mat1, mat2, alpha, beta):
- a = mat.addmm(mat1, mat2)
- b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0)
- c = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0)
- d = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta))
+ def does_decompose():
+ @torch.jit.script
+ def addmm(mat, mat1, mat2, alpha, beta):
+ a = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0)
+ b = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta))
- return a + b + c + d
+ return a + b
- mat = torch.randn(2, 2)
- mat1 = torch.randn(2, 4)
- mat2 = torch.randn(4, 2)
- alpha = torch.FloatTensor([123.0])
- beta = torch.FloatTensor([321.0])
+ mat = torch.randn(2, 2)
+ mat1 = torch.randn(2, 4)
+ mat2 = torch.randn(4, 2)
+ alpha = torch.FloatTensor([123.0])
+ beta = torch.FloatTensor([321.0])
- out_ref = addmm(mat, mat1, mat2, alpha, beta)
- self.run_pass('canonicalize_ops', addmm.graph)
- out_test = addmm(mat, mat1, mat2, alpha, beta)
- self.assertEqual(out_ref, out_test)
- self.assertExpected(canonical(addmm.graph))
+ out_ref = addmm(mat, mat1, mat2, alpha, beta)
+ self.run_pass('canonicalize_ops', addmm.graph)
+ out_test = addmm(mat, mat1, mat2, alpha, beta)
+ self.assertEqual(out_ref, out_test)
+ FileCheck().check_not("addmm").run(str(addmm.graph))
+
+ def doesnt_decompose():
+ @torch.jit.script
+ def addmm(mat, mat1, mat2, alpha, beta):
+ a = mat.addmm(mat1, mat2)
+ b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0)
+
+ orig = str(addmm.graph)
+ self.run_pass('canonicalize_ops', addmm.graph)
+ self.assertTrue(orig == str(addmm.graph))
def test_index_put(self):
ten = torch.zeros(3, 3)
self.assertEqual(res, res_batch.examples())
script_if = torch.jit.script(single_if)
- graph = torch.to_batch_graph(script_if.graph)
- self.assertExpected(canonical(graph))
+ torch.to_batch_graph(script_if.graph)
def test_if_else_with_scalar(self):
def single_if(a, b):
self.assertEqual(res, res_batch.examples())
script_if = torch.jit.script(single_if)
- graph = torch.to_batch_graph(script_if.graph)
- self.assertExpected(canonical(graph))
+ torch.to_batch_graph(script_if.graph)
def test_if_noelse(self):
def single_if(a, b):
self.assertEqual(res, res_batch.examples())
script_if = torch.jit.script(single_if)
- graph = torch.to_batch_graph(script_if.graph)
- self.assertExpected(canonical(graph))
+ torch.to_batch_graph(script_if.graph)
def test_if_noelse_with_scalar(self):
def single_if(a, b):
self.assertEqual(res, res_batch.examples())
script_if = torch.jit.script(single_if)
- graph = torch.to_batch_graph(script_if.graph)
- self.assertExpected(canonical(graph))
+ torch.to_batch_graph(script_if.graph)
def test_while(self):
def single_while(a, b):
self.assertEqual(res, res_batch.examples())
script_while = torch.jit.script(single_while)
- graph = torch.to_batch_graph(script_while.graph)
- self.assertExpected(canonical(graph))
+ torch.to_batch_graph(script_while.graph)
def test_for(self):
def single_for(x, y):
self.assertEqual(res, res_batch.examples())
script_for = torch.jit.script(single_for)
- graph = torch.to_batch_graph(script_for.graph)
- self.assertExpected(canonical(graph))
+ torch.to_batch_graph(script_for.graph)
def test_lstm(self):
def LSTM(x_all, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c):
def func2(x):
return x.sum(dim=4)
- self.assertExpected(canonical(func.graph), subname='1')
# test that shape analysis is written correctly for sum with IntArrayRef[1] dim argument
+ self.run_pass('constant_propagation', func.graph)
+ self.run_pass('constant_propagation', func2.graph)
+ torch._C._jit_pass_shape_analysis(
+ func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
torch._C._jit_pass_shape_analysis(
func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
- self.assertExpected(canonical(func2.graph), subname='2')
+ self.assertTrue(func.graph.findNode("aten::sum").output().type().kind()
+ == "DimensionedTensorType")
+ self.assertTrue(func2.graph.findNode("aten::sum").output().type().kind()
+ == "DimensionedTensorType")
def test_cat(self):
@torch.jit.script
def foo3(x):
return torch.cat([x], dim=1)
- self.assertExpected(
- canonical(foo.graph) +
- canonical(foo2.graph) +
- canonical(foo3.graph))
+ for g in [foo.graph, foo2.graph, foo3.graph]:
+ FileCheck().check("int =").check("ListConstruct").check("aten::cat").run(str(g))
def test_list_literal(self):
def reassign():
# The neg op in the python function should be properly inlined to the
# graph
- self.assertExpected(canonical(traced_fn.graph))
+ FileCheck().check("aten::neg").run(str(traced_fn.graph))
def test_call_python_mod_from_tracing_fn(self):
class PythonMod(torch.nn.Module):
self.checkScript(code, (101,), name='elif_test', outputs=3028)
- def test_addmm_fusion(self):
- class AddmmWrapper(torch.nn.Module):
- def forward(self, x, y, c):
- return torch.mm(x, y) + c
-
- # Test addmm fusion is disabled for normal Jit
- x, y, c = torch.rand(3, 4), torch.rand(4, 5), torch.rand(3, 5)
- f = io.BytesIO()
- pretty = torch.onnx.export_to_pretty_string(AddmmWrapper(), (x, y, c), f)
- self.assertExpected(pretty, 'onnx')
-
- jit_trace = torch.jit.trace(AddmmWrapper(), (x, y, c))
- ge_graph = jit_trace.__getattr__('forward').graph_for(x, y, c)
- self.assertExpectedGraph(ge_graph, 'jit')
-
def test_pyop_exception_message(self):
class Foo(torch.jit.ScriptModule):
def __init__(self):
x = torch.rand(2, 3, 4)
traced = torch.jit.trace(foo, (x,))
- self.assertExpectedGraph(traced.graph)
+ FileCheck().check("aten::contiguous").run(str(traced.graph))
def test_weak_module(self):
torch.randn(4, dtype=torch.float, device='cuda'),
]
ge = self.checkTrace(scaleshift, inputs)
- self.assertExpectedGraph(ge.graph_for(*inputs))
+ self.assertAllFused(ge.graph_for(*inputs))
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')]
ge = self.checkScript(fn, inputs)
- self.assertExpectedGraph(ge.graph_for(*inputs))
+ graph = ge.graph_for(*inputs)
+ self.assertAllFused(graph)
+ FileCheck().check("prim::ConstantChunk[chunks=3, dim=1]").run(str(graph))
@staticmethod
def _test_chunk_correctness(self, device='cpu'):
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
ge = self.checkTrace(f, (x, y))
- self.assertExpectedGraph(ge.graph_for(x, y))
+ graph = ge.graph_for(x, y)
+ FileCheck().check("broadcast_tensors").check('with prim::FusionGroup_0') \
+ .check_count('ConstantChunk', 2, exactly=True).run(str(graph))
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
]
ge = self.checkScript(fn, inputs)
- self.assertExpectedGraph(ge.graph_for(*inputs))
+ self.assertAllFused(ge.graph_for(*inputs))
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
return torch.cat((hx + cx, hx * cx))
ge = self.checkTrace(foo, (hx, cx))
- self.assertExpectedGraph(ge.graph_for(hx, cx))
+ graph = ge.graph_for(hx, cx)
+ self.assertAllFused(graph)
+ FileCheck().check("FusedConcat").check_next("return").run(str(graph))
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
y = torch.randn(2, 2, dtype=torch.float, device='cuda')
z = torch.randn(4, 2, dtype=torch.float, device='cuda')
ge = self.checkTrace(fn, (x, y, z))
- self.assertExpectedGraph(ge.graph_for(x, y, z))
+ graph = ge.graph_for(x, y, z)
+ self.assertAllFused(graph, except_for={'aten::add'})
+ FileCheck().check("FusedConcat").check_next("return").run(str(graph))
@staticmethod
def fn_test_exp(x, y):
forward_graph = module.graph_for(*inputs)
self.assertGraphContainsExactly(
forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
- self.assertExpectedGraph(forward_graph, subname='forward')
+ self.assertTrue(len(list(forward_graph.nodes())) == 2)
+ # Everything is differentiable but TupleConstruct return
+ FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
+ .check_next("return").run(str(forward_graph))
hy, cy = module(*inputs)
(hy + cy).sum().backward()
- self.assertExpectedGraph(backward_graph(module), subname='backward')
+ backward = backward_graph(module)
+ FileCheck().check("FusionGroup_0").check_next("FusionGroup_1") \
+ .check_not("FusionGroup_2").run(str(backward))
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_lstm_concat_cuda(self):
inputs = get_lstm_inputs('cuda')
ge = self.checkTrace(LSTMCellC, inputs)
- self.assertExpectedGraph(ge.graph_for(*inputs))
+ graph = ge.graph_for(*inputs)
+ FileCheck().check("FusedConcat").check_next("return").run(str(graph))
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_lstm_traced_cuda(self):
inputs = get_lstm_inputs('cuda')
ge = self.checkTrace(LSTMCellF, inputs)
- self.assertExpectedGraph(ge.graph_for(*inputs))
+ graph = ge.graph_for(*inputs)
+ FileCheck().check_not("Chunk").check_not("aten::add").check_not("aten::sigmoid") \
+ .check_not("aten::tanh").check("FusionGroup").check_next("TupleConstruct") \
+ .check_next("return").check_not("FusionGroup_1").run(str(graph))
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
@unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
inputs = get_lstm_inputs('cpu')
try:
ge = self.checkTrace(LSTMCellF, inputs)
- self.assertExpectedGraph(ge.graph_for(*inputs))
+ graph = ge.graph_for(*inputs)
+ FileCheck().check("FusionGroup").run(str(graph))
except RuntimeError as e:
if 'Failed to compile' in e.args[0]:
warnings.warn('CPU fuser test has failed! This is not a hard failure, '
forward_graph = module.graph_for(*inputs)
self.assertGraphContainsExactly(
forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
- self.assertExpectedGraph(forward_graph, subname='forward')
-
+ FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
+ .check_next("return").check("FusionGroup").run(str(forward_graph))
hy, cy = module(*inputs)
(hy + cy).sum().backward()
- self.assertExpectedGraph(backward_graph(module), subname='backward')
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
x = torch.tensor(0.1, dtype=torch.float, device='cpu')
y = torch.tensor(1, dtype=torch.float, device='cpu')
ge = self.checkScript(fn, (x, y))
- self.assertExpectedGraph(ge.graph_for(x, y))
+ self.assertAllFused(ge.graph_for(x, y))
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
TORCH_API FileCheck* check(const std::string& str);
// Checks that the string does not occur between the previous match and next
- // match Consecutive check_nots test against the same previous match and next
+ // match. Consecutive check_nots test against the same previous match and next
// match
TORCH_API FileCheck* check_not(const std::string& str);
// previous match
TORCH_API FileCheck* check_next(const std::string& str);
- // Checks that the string occurs count number of times. If exactly is true,
- // checks that there are exactly count many matches
+ // Checks that the string occurs count number of times, starting at the end
+ // of the previous match. If exactly is true, checks that there are exactly
+ // count many matches
TORCH_API FileCheck* check_count(
const std::string& str,
size_t count,
bool exactly = false);
// A series of consecutive check_dags get turned into a group of checks
- // which can appear in any order relative to each other.
+ // which can appear in any order relative to each other. The checks begin
+ // at the end of the previous match, and the match for the check_dag group
+ // is the minimum match of all individual checks to the maximum match of all
+ // individual checks.
TORCH_API FileCheck* check_dag(const std::string& str);
// reset checks