with prim::FusionGroup_0 = graph(%0 : Float(*)
%1 : Float(*, *)
%2 : Float(*)) {
- %3 : Float(*, *) = aten::mul(%1, %2)
- %4 : int = prim::Constant[value=1]()
- %5 : Float(*, *) = aten::add(%3, %0, %4)
+ %3 : int = prim::Constant[value=1]()
+ %4 : Float(*, *) = aten::mul(%1, %2)
+ %5 : Float(*, *) = aten::add(%4, %0, %3)
return (%5);
}
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)) {
%1 : Float(*, *), %2 : Float(*, *), %3 : Float(*, *) = prim::ConstantChunk[chunks=3, dim=1](%0)
- %4 : Float(*, *) = aten::mul(%1, %2)
- %5 : int = prim::Constant[value=1]()
- %6 : Float(*, *) = aten::add(%4, %3, %5)
+ %4 : int = prim::Constant[value=1]()
+ %5 : Float(*, *) = aten::mul(%1, %2)
+ %6 : Float(*, *) = aten::add(%5, %3, %4)
return (%6);
}
%4 : Float(*, *), %5 : Float(*, *) = prim::ConstantChunk[chunks=2, dim=1](%0)
%6 : int = prim::Constant[value=1]()
%7 : Float(*, *) = aten::add(%2, %4, %6)
- %8 : int = prim::Constant[value=1]()
- %9 : Float(*, *) = aten::add(%3, %5, %8)
- %10 : Float(*, *) = aten::mul(%7, %9)
- return (%10);
+ %8 : Float(*, *) = aten::add(%3, %5, %6)
+ %9 : Float(*, *) = aten::mul(%7, %8)
+ return (%9);
}
%9 : Float(*, *, *), %10 : Float(*, *, *) = prim::ConstantChunk[chunks=2, dim=0](%1)
%11 : int = prim::Constant[value=1]()
%12 : Float(*, *, *) = aten::add(%0, %6, %11)
- %13 : int = prim::Constant[value=1]()
- %14 : Float(*, *, *) = aten::add(%12, %7, %13)
- %15 : int = prim::Constant[value=1]()
- %16 : Float(*, *, *) = aten::add(%14, %8, %15)
- %17 : int = prim::Constant[value=1]()
- %18 : Float(*, *, *) = aten::add(%16, %9, %17)
- %19 : int = prim::Constant[value=1]()
- %20 : Float(*, *, *) = aten::add(%18, %10, %19)
- %21 : int = prim::Constant[value=1]()
- %22 : Float(*, *, *) = aten::add(%20, %4, %21)
- %23 : int = prim::Constant[value=1]()
- %24 : Float(*, *, *) = aten::add(%22, %5, %23)
- return (%24);
+ %13 : Float(*, *, *) = aten::add(%12, %7, %11)
+ %14 : Float(*, *, *) = aten::add(%13, %8, %11)
+ %15 : Float(*, *, *) = aten::add(%14, %9, %11)
+ %16 : Float(*, *, *) = aten::add(%15, %10, %11)
+ %17 : Float(*, *, *) = aten::add(%16, %4, %11)
+ %18 : Float(*, *, *) = aten::add(%17, %5, %11)
+ return (%18);
}
%1 : Float(*, *)) {
%2 : int = prim::Constant[value=1]()
%x1 : Float(*, *) = aten::add(%0, %1, %2)
- %4 : int = prim::Constant[value=1]()
- %y1 : Float(*, *) = aten::sub(%0, %1, %4)
+ %y1 : Float(*, *) = aten::sub(%0, %1, %2)
%w : Float(*, *) = prim::FusedConcat[dim=0](%x1, %y1)
return (%w);
}
%17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
%21 : int = prim::Constant[value=1]()
%22 : Float(*, *) = aten::add(%13, %17, %21)
- %23 : int = prim::Constant[value=1]()
- %24 : Float(*, *) = aten::add(%14, %18, %23)
- %25 : int = prim::Constant[value=1]()
- %26 : Float(*, *) = aten::add(%15, %19, %25)
- %27 : int = prim::Constant[value=1]()
- %28 : Float(*, *) = aten::add(%16, %20, %27)
- %29 : int = prim::Constant[value=1]()
- %30 : Float(*, *) = aten::add(%5, %9, %29)
- %31 : int = prim::Constant[value=1]()
- %32 : Float(*, *) = aten::add(%6, %10, %31)
- %33 : int = prim::Constant[value=1]()
- %34 : Float(*, *) = aten::add(%7, %11, %33)
- %35 : int = prim::Constant[value=1]()
- %36 : Float(*, *) = aten::add(%8, %12, %35)
- %37 : int = prim::Constant[value=1]()
- %38 : Float(*, *) = aten::add(%30, %22, %37)
- %39 : int = prim::Constant[value=1]()
- %40 : Float(*, *) = aten::add(%32, %24, %39)
- %41 : int = prim::Constant[value=1]()
- %42 : Float(*, *) = aten::add(%34, %26, %41)
- %43 : int = prim::Constant[value=1]()
- %44 : Float(*, *) = aten::add(%36, %28, %43)
- %ingate : Float(*, *) = aten::sigmoid(%38)
- %forgetgate : Float(*, *) = aten::sigmoid(%40)
- %cellgate : Float(*, *) = aten::tanh(%42)
- %outgate : Float(*, *) = aten::sigmoid(%44)
- %49 : Float(*, *) = aten::mul(%forgetgate, %0)
- %50 : Float(*, *) = aten::mul(%ingate, %cellgate)
- %51 : int = prim::Constant[value=1]()
- %cy : Float(*, *) = aten::add(%49, %50, %51)
- %53 : Float(*, *) = aten::tanh(%cy)
- %hy : Float(*, *) = aten::mul(%outgate, %53)
- %55 : Float(*, *) = prim::FusedConcat[dim=0](%hy, %cy)
- return (%55);
+ %23 : Float(*, *) = aten::add(%14, %18, %21)
+ %24 : Float(*, *) = aten::add(%15, %19, %21)
+ %25 : Float(*, *) = aten::add(%16, %20, %21)
+ %26 : Float(*, *) = aten::add(%5, %9, %21)
+ %27 : Float(*, *) = aten::add(%6, %10, %21)
+ %28 : Float(*, *) = aten::add(%7, %11, %21)
+ %29 : Float(*, *) = aten::add(%8, %12, %21)
+ %30 : Float(*, *) = aten::add(%26, %22, %21)
+ %31 : Float(*, *) = aten::add(%27, %23, %21)
+ %32 : Float(*, *) = aten::add(%28, %24, %21)
+ %33 : Float(*, *) = aten::add(%29, %25, %21)
+ %ingate : Float(*, *) = aten::sigmoid(%30)
+ %forgetgate : Float(*, *) = aten::sigmoid(%31)
+ %cellgate : Float(*, *) = aten::tanh(%32)
+ %outgate : Float(*, *) = aten::sigmoid(%33)
+ %38 : Float(*, *) = aten::mul(%forgetgate, %0)
+ %39 : Float(*, *) = aten::mul(%ingate, %cellgate)
+ %cy : Float(*, *) = aten::add(%38, %39, %21)
+ %41 : Float(*, *) = aten::tanh(%cy)
+ %hy : Float(*, *) = aten::mul(%outgate, %41)
+ %43 : Float(*, *) = prim::FusedConcat[dim=0](%hy, %cy)
+ return (%43);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Tensor
%2 : Float(*, *)) {
- %3 : Float(*, *) = aten::mul(%2, %2)
- %4 : Float(*, *) = aten::neg(%3)
- %5 : int = prim::Constant[value=1]()
- %6 : Float(*, *) = aten::add(%4, %5, %5)
+ %3 : int = prim::Constant[value=1]()
+ %4 : Float(*, *) = aten::mul(%2, %2)
+ %5 : Float(*, *) = aten::neg(%4)
+ %6 : Float(*, *) = aten::add(%5, %3, %3)
%7 : Tensor = aten::mul(%1, %6)
- %8 : int = prim::Constant[value=1]()
- %9 : Tensor = aten::add(%0, %7, %8)
- return (%9);
+ %8 : Tensor = aten::add(%0, %7, %3)
+ return (%8);
}
with prim::FusionGroup_1 = graph(%0 : Tensor
%1 : Float(*, *)
%5 : Float(*, *)
%6 : Tensor
%7 : Float(*, *)) {
- %8 : Float(*, *) = aten::neg(%7)
- %9 : int = prim::Constant[value=1]()
- %10 : Float(*, *) = aten::add(%8, %9, %9)
+ %8 : int = prim::Constant[value=1]()
+ %9 : Float(*, *) = aten::neg(%7)
+ %10 : Float(*, *) = aten::add(%9, %8, %8)
%11 : Float(*, *) = aten::mul(%10, %7)
%12 : Tensor = aten::mul(%11, %6)
%13 : Float(*, *) = aten::mul(%5, %5)
%14 : Float(*, *) = aten::neg(%13)
- %15 : int = prim::Constant[value=1]()
- %16 : Float(*, *) = aten::add(%14, %15, %15)
- %17 : Tensor = aten::mul(%4, %16)
- %18 : Float(*, *) = aten::neg(%3)
- %19 : int = prim::Constant[value=1]()
- %20 : Float(*, *) = aten::add(%18, %19, %19)
- %21 : Float(*, *) = aten::mul(%20, %3)
- %22 : Tensor = aten::mul(%21, %2)
- %23 : Float(*, *) = aten::neg(%1)
- %24 : int = prim::Constant[value=1]()
- %25 : Float(*, *) = aten::add(%23, %24, %24)
- %26 : Float(*, *) = aten::mul(%25, %1)
- %27 : Tensor = aten::mul(%26, %0)
- %28 : Tensor = prim::FusedConcat[dim=1](%27, %22, %17, %12)
- return (%28);
+ %15 : Float(*, *) = aten::add(%14, %8, %8)
+ %16 : Tensor = aten::mul(%4, %15)
+ %17 : Float(*, *) = aten::neg(%3)
+ %18 : Float(*, *) = aten::add(%17, %8, %8)
+ %19 : Float(*, *) = aten::mul(%18, %3)
+ %20 : Tensor = aten::mul(%19, %2)
+ %21 : Float(*, *) = aten::neg(%1)
+ %22 : Float(*, *) = aten::add(%21, %8, %8)
+ %23 : Float(*, *) = aten::mul(%22, %1)
+ %24 : Tensor = aten::mul(%23, %0)
+ %25 : Tensor = prim::FusedConcat[dim=1](%24, %20, %16, %12)
+ return (%25);
}
%17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
%21 : int = prim::Constant[value=1]()
%22 : Float(*, *) = aten::add(%13, %17, %21)
- %23 : int = prim::Constant[value=1]()
- %24 : Float(*, *) = aten::add(%14, %18, %23)
- %25 : int = prim::Constant[value=1]()
- %26 : Float(*, *) = aten::add(%15, %19, %25)
- %27 : int = prim::Constant[value=1]()
- %28 : Float(*, *) = aten::add(%16, %20, %27)
- %29 : int = prim::Constant[value=1]()
- %30 : Float(*, *) = aten::add(%22, %9, %29)
- %31 : int = prim::Constant[value=1]()
- %32 : Float(*, *) = aten::add(%24, %10, %31)
- %33 : int = prim::Constant[value=1]()
- %34 : Float(*, *) = aten::add(%26, %11, %33)
- %35 : int = prim::Constant[value=1]()
- %36 : Float(*, *) = aten::add(%28, %12, %35)
- %37 : int = prim::Constant[value=1]()
- %38 : Float(*, *) = aten::add(%30, %5, %37)
- %39 : int = prim::Constant[value=1]()
- %40 : Float(*, *) = aten::add(%32, %6, %39)
- %41 : int = prim::Constant[value=1]()
- %42 : Float(*, *) = aten::add(%34, %7, %41)
- %43 : int = prim::Constant[value=1]()
- %44 : Float(*, *) = aten::add(%36, %8, %43)
- %ingate.1 : Float(*, *) = aten::sigmoid(%38)
- %forgetgate.1 : Float(*, *) = aten::sigmoid(%40)
- %cellgate.1 : Float(*, *) = aten::tanh(%42)
- %outgate.1 : Float(*, *) = aten::sigmoid(%44)
- %49 : Float(*, *) = aten::mul(%forgetgate.1, %0)
- %50 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
- %51 : int = prim::Constant[value=1]()
- %cy : Float(*, *) = aten::add(%49, %50, %51)
- %53 : Float(*, *) = aten::tanh(%cy)
- %hy : Float(*, *) = aten::mul(%outgate.1, %53)
- return (%hy, %53, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
+ %23 : Float(*, *) = aten::add(%14, %18, %21)
+ %24 : Float(*, *) = aten::add(%15, %19, %21)
+ %25 : Float(*, *) = aten::add(%16, %20, %21)
+ %26 : Float(*, *) = aten::add(%22, %9, %21)
+ %27 : Float(*, *) = aten::add(%23, %10, %21)
+ %28 : Float(*, *) = aten::add(%24, %11, %21)
+ %29 : Float(*, *) = aten::add(%25, %12, %21)
+ %30 : Float(*, *) = aten::add(%26, %5, %21)
+ %31 : Float(*, *) = aten::add(%27, %6, %21)
+ %32 : Float(*, *) = aten::add(%28, %7, %21)
+ %33 : Float(*, *) = aten::add(%29, %8, %21)
+ %ingate.1 : Float(*, *) = aten::sigmoid(%30)
+ %forgetgate.1 : Float(*, *) = aten::sigmoid(%31)
+ %cellgate.1 : Float(*, *) = aten::tanh(%32)
+ %outgate.1 : Float(*, *) = aten::sigmoid(%33)
+ %38 : Float(*, *) = aten::mul(%forgetgate.1, %0)
+ %39 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
+ %cy : Float(*, *) = aten::add(%38, %39, %21)
+ %41 : Float(*, *) = aten::tanh(%cy)
+ %hy : Float(*, *) = aten::mul(%outgate.1, %41)
+ return (%hy, %41, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
}
%17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
%21 : int = prim::Constant[value=1]()
%22 : Float(*, *) = aten::add(%13, %17, %21)
- %23 : int = prim::Constant[value=1]()
- %24 : Float(*, *) = aten::add(%14, %18, %23)
- %25 : int = prim::Constant[value=1]()
- %26 : Float(*, *) = aten::add(%15, %19, %25)
- %27 : int = prim::Constant[value=1]()
- %28 : Float(*, *) = aten::add(%16, %20, %27)
- %29 : int = prim::Constant[value=1]()
- %30 : Float(*, *) = aten::add(%5, %9, %29)
- %31 : int = prim::Constant[value=1]()
- %32 : Float(*, *) = aten::add(%6, %10, %31)
- %33 : int = prim::Constant[value=1]()
- %34 : Float(*, *) = aten::add(%7, %11, %33)
- %35 : int = prim::Constant[value=1]()
- %36 : Float(*, *) = aten::add(%8, %12, %35)
- %37 : int = prim::Constant[value=1]()
- %38 : Float(*, *) = aten::add(%30, %22, %37)
- %39 : int = prim::Constant[value=1]()
- %40 : Float(*, *) = aten::add(%32, %24, %39)
- %41 : int = prim::Constant[value=1]()
- %42 : Float(*, *) = aten::add(%34, %26, %41)
- %43 : int = prim::Constant[value=1]()
- %44 : Float(*, *) = aten::add(%36, %28, %43)
- %ingate : Float(*, *) = aten::sigmoid(%38)
- %forgetgate : Float(*, *) = aten::sigmoid(%40)
- %cellgate : Float(*, *) = aten::tanh(%42)
- %outgate : Float(*, *) = aten::sigmoid(%44)
- %49 : Float(*, *) = aten::mul(%forgetgate, %0)
- %50 : Float(*, *) = aten::mul(%ingate, %cellgate)
- %51 : int = prim::Constant[value=1]()
- %cy : Float(*, *) = aten::add(%49, %50, %51)
- %53 : Float(*, *) = aten::tanh(%cy)
- %54 : Float(*, *) = aten::mul(%outgate, %53)
- return (%54, %cy);
+ %23 : Float(*, *) = aten::add(%14, %18, %21)
+ %24 : Float(*, *) = aten::add(%15, %19, %21)
+ %25 : Float(*, *) = aten::add(%16, %20, %21)
+ %26 : Float(*, *) = aten::add(%5, %9, %21)
+ %27 : Float(*, *) = aten::add(%6, %10, %21)
+ %28 : Float(*, *) = aten::add(%7, %11, %21)
+ %29 : Float(*, *) = aten::add(%8, %12, %21)
+ %30 : Float(*, *) = aten::add(%26, %22, %21)
+ %31 : Float(*, *) = aten::add(%27, %23, %21)
+ %32 : Float(*, *) = aten::add(%28, %24, %21)
+ %33 : Float(*, *) = aten::add(%29, %25, %21)
+ %ingate : Float(*, *) = aten::sigmoid(%30)
+ %forgetgate : Float(*, *) = aten::sigmoid(%31)
+ %cellgate : Float(*, *) = aten::tanh(%32)
+ %outgate : Float(*, *) = aten::sigmoid(%33)
+ %38 : Float(*, *) = aten::mul(%forgetgate, %0)
+ %39 : Float(*, *) = aten::mul(%ingate, %cellgate)
+ %cy : Float(*, *) = aten::add(%38, %39, %21)
+ %41 : Float(*, *) = aten::tanh(%cy)
+ %42 : Float(*, *) = aten::mul(%outgate, %41)
+ return (%42, %cy);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Tensor
%2 : Float(*, *)) {
- %3 : Float(*, *) = aten::mul(%2, %2)
- %4 : Float(*, *) = aten::neg(%3)
- %5 : int = prim::Constant[value=1]()
- %6 : Float(*, *) = aten::add(%4, %5, %5)
+ %3 : int = prim::Constant[value=1]()
+ %4 : Float(*, *) = aten::mul(%2, %2)
+ %5 : Float(*, *) = aten::neg(%4)
+ %6 : Float(*, *) = aten::add(%5, %3, %3)
%7 : Tensor = aten::mul(%1, %6)
- %8 : int = prim::Constant[value=1]()
- %9 : Tensor = aten::add(%0, %7, %8)
- return (%9);
+ %8 : Tensor = aten::add(%0, %7, %3)
+ return (%8);
}
with prim::FusionGroup_1 = graph(%0 : Tensor
%1 : Float(*, *)
%5 : Float(*, *)
%6 : Tensor
%7 : Float(*, *)) {
- %8 : Float(*, *) = aten::neg(%7)
- %9 : int = prim::Constant[value=1]()
- %10 : Float(*, *) = aten::add(%8, %9, %9)
+ %8 : int = prim::Constant[value=1]()
+ %9 : Float(*, *) = aten::neg(%7)
+ %10 : Float(*, *) = aten::add(%9, %8, %8)
%11 : Float(*, *) = aten::mul(%10, %7)
%12 : Tensor = aten::mul(%11, %6)
%13 : Float(*, *) = aten::mul(%5, %5)
%14 : Float(*, *) = aten::neg(%13)
- %15 : int = prim::Constant[value=1]()
- %16 : Float(*, *) = aten::add(%14, %15, %15)
- %17 : Tensor = aten::mul(%4, %16)
- %18 : Float(*, *) = aten::neg(%3)
- %19 : int = prim::Constant[value=1]()
- %20 : Float(*, *) = aten::add(%18, %19, %19)
- %21 : Float(*, *) = aten::mul(%20, %3)
- %22 : Tensor = aten::mul(%21, %2)
- %23 : Float(*, *) = aten::neg(%1)
- %24 : int = prim::Constant[value=1]()
- %25 : Float(*, *) = aten::add(%23, %24, %24)
- %26 : Float(*, *) = aten::mul(%25, %1)
- %27 : Tensor = aten::mul(%26, %0)
- %28 : Tensor = prim::FusedConcat[dim=1](%27, %22, %17, %12)
- return (%28);
+ %15 : Float(*, *) = aten::add(%14, %8, %8)
+ %16 : Tensor = aten::mul(%4, %15)
+ %17 : Float(*, *) = aten::neg(%3)
+ %18 : Float(*, *) = aten::add(%17, %8, %8)
+ %19 : Float(*, *) = aten::mul(%18, %3)
+ %20 : Tensor = aten::mul(%19, %2)
+ %21 : Float(*, *) = aten::neg(%1)
+ %22 : Float(*, *) = aten::add(%21, %8, %8)
+ %23 : Float(*, *) = aten::mul(%22, %1)
+ %24 : Tensor = aten::mul(%23, %0)
+ %25 : Tensor = prim::FusedConcat[dim=1](%24, %20, %16, %12)
+ return (%25);
}
%19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
%23 : Float(*, *), %24 : Float(*, *), %25 : Float(*, *), %26 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
%27 : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %30 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
- %31 : Float(*, *) = aten::mul(%23, %27)
- %32 : Float(*, *) = aten::mul(%24, %28)
- %33 : Float(*, *) = aten::mul(%25, %29)
- %34 : Float(*, *) = aten::mul(%26, %30)
- %35 : Float(*, *) = aten::mul(%19, %15)
- %36 : Float(*, *) = aten::mul(%20, %16)
- %37 : Float(*, *) = aten::mul(%21, %17)
- %38 : Float(*, *) = aten::mul(%22, %18)
- %39 : Float(*, *) = aten::mul(%11, %15)
- %40 : Float(*, *) = aten::mul(%12, %16)
- %41 : Float(*, *) = aten::mul(%13, %17)
- %42 : Float(*, *) = aten::mul(%14, %18)
- %43 : int = prim::Constant[value=1]()
- %44 : Float(*, *) = aten::add(%35, %31, %43)
- %45 : int = prim::Constant[value=1]()
- %46 : Float(*, *) = aten::add(%36, %32, %45)
- %47 : int = prim::Constant[value=1]()
- %48 : Float(*, *) = aten::add(%37, %33, %47)
- %49 : int = prim::Constant[value=1]()
- %50 : Float(*, *) = aten::add(%38, %34, %49)
- %51 : int = prim::Constant[value=1]()
- %52 : Float(*, *) = aten::add(%44, %39, %51)
- %53 : int = prim::Constant[value=1]()
- %54 : Float(*, *) = aten::add(%46, %40, %53)
- %55 : int = prim::Constant[value=1]()
- %56 : Float(*, *) = aten::add(%48, %41, %55)
- %57 : int = prim::Constant[value=1]()
- %58 : Float(*, *) = aten::add(%50, %42, %57)
- %59 : int = prim::Constant[value=1]()
- %60 : Float(*, *) = aten::add(%52, %7, %59)
- %61 : int = prim::Constant[value=1]()
- %62 : Float(*, *) = aten::add(%54, %8, %61)
- %63 : int = prim::Constant[value=1]()
- %64 : Float(*, *) = aten::add(%56, %9, %63)
- %65 : int = prim::Constant[value=1]()
- %66 : Float(*, *) = aten::add(%58, %10, %65)
- %ingate.1 : Float(*, *) = aten::sigmoid(%60)
- %forgetgate.1 : Float(*, *) = aten::sigmoid(%62)
- %cellgate.1 : Float(*, *) = aten::tanh(%64)
- %outgate.1 : Float(*, *) = aten::sigmoid(%66)
- %71 : Float(*, *) = aten::mul(%forgetgate.1, %0)
- %72 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
- %73 : int = prim::Constant[value=1]()
- %cy : Float(*, *) = aten::add(%71, %72, %73)
- %75 : Float(*, *) = aten::tanh(%cy)
- %hy : Float(*, *) = aten::mul(%outgate.1, %75)
- return (%hy, %75, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
+ %31 : int = prim::Constant[value=1]()
+ %32 : Float(*, *) = aten::mul(%23, %27)
+ %33 : Float(*, *) = aten::mul(%24, %28)
+ %34 : Float(*, *) = aten::mul(%25, %29)
+ %35 : Float(*, *) = aten::mul(%26, %30)
+ %36 : Float(*, *) = aten::mul(%19, %15)
+ %37 : Float(*, *) = aten::mul(%20, %16)
+ %38 : Float(*, *) = aten::mul(%21, %17)
+ %39 : Float(*, *) = aten::mul(%22, %18)
+ %40 : Float(*, *) = aten::mul(%11, %15)
+ %41 : Float(*, *) = aten::mul(%12, %16)
+ %42 : Float(*, *) = aten::mul(%13, %17)
+ %43 : Float(*, *) = aten::mul(%14, %18)
+ %44 : Float(*, *) = aten::add(%36, %32, %31)
+ %45 : Float(*, *) = aten::add(%37, %33, %31)
+ %46 : Float(*, *) = aten::add(%38, %34, %31)
+ %47 : Float(*, *) = aten::add(%39, %35, %31)
+ %48 : Float(*, *) = aten::add(%44, %40, %31)
+ %49 : Float(*, *) = aten::add(%45, %41, %31)
+ %50 : Float(*, *) = aten::add(%46, %42, %31)
+ %51 : Float(*, *) = aten::add(%47, %43, %31)
+ %52 : Float(*, *) = aten::add(%48, %7, %31)
+ %53 : Float(*, *) = aten::add(%49, %8, %31)
+ %54 : Float(*, *) = aten::add(%50, %9, %31)
+ %55 : Float(*, *) = aten::add(%51, %10, %31)
+ %ingate.1 : Float(*, *) = aten::sigmoid(%52)
+ %forgetgate.1 : Float(*, *) = aten::sigmoid(%53)
+ %cellgate.1 : Float(*, *) = aten::tanh(%54)
+ %outgate.1 : Float(*, *) = aten::sigmoid(%55)
+ %60 : Float(*, *) = aten::mul(%forgetgate.1, %0)
+ %61 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
+ %cy : Float(*, *) = aten::add(%60, %61, %31)
+ %63 : Float(*, *) = aten::tanh(%cy)
+ %hy : Float(*, *) = aten::mul(%outgate.1, %63)
+ return (%hy, %63, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
}
}
with prim::FusionGroup_0 = graph(%0 : Float()
%1 : Float()) {
- %2 : int = prim::Constant[value=2]()
- %3 : Float() = aten::mul(%1, %2)
- %4 : int = prim::Constant[value=1]()
- %5 : Float() = aten::add(%3, %0, %4)
+ %2 : int = prim::Constant[value=1]()
+ %3 : int = prim::Constant[value=2]()
+ %4 : Float() = aten::mul(%1, %3)
+ %5 : Float() = aten::add(%4, %0, %2)
return (%5);
}
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/script/compiler.h>
return true;
}
- // Can this node produce an _output_ of a fusion group?
- // all Fusable nodes can do this, but additionally Concat, which normally
- // cannot be fused because it is not a simple map, can be put in a fusion
- // group as long as no items in the group read the output of concat
- bool isFusableAsExitNode(Node* node) {
- return isFusableMap(node) || isFusableOnlyAsExitNode(node);
- }
-
- bool isFusableOnlyAsExitNode(Node* node) {
- return isFusableCatNode(node) || node->kind() == prim::FusedConcat;
- }
-
bool calculatesSize(Node* node) {
return node->matches("aten::size(Tensor self) -> int[]");
}
return true;
}
- bool mustRemainAsFusionGroupOutput(Value* producer) {
- if (producer->node()->kind() != prim::FusionGroup) {
- return false;
- }
- auto subgraph = producer->node()->g(attr::Subgraph);
- auto* node = subgraph->outputs().at(producer->offset())->node();
- return isFusableOnlyAsExitNode(node);
- }
-
Graph& getSubgraph(Node* n) {
JIT_ASSERT(n->kind() == prim::FusionGroup);
return *n->g(attr::Subgraph);
// we can move the consumer up into the producer.
// but this requires better handling of merging fusion groups so it is not
// done now
- Node* real_consumer = consumer->kind() == aten::cat
- ? consumer->namedInput(attr::tensors)->node()
- : consumer;
bool shouldFuse = isFusable(producer->node()) &&
// Rearrange nodes such that all uses of producer are after the
// consumer. Fusion will rewrite those later uses to use the version of
// producer generated by the fused blob. In this case, producer becomes
// an output of the fusion group.
producer->node()->moveBeforeTopologicallyValid(
- real_consumer, aliasDb_.value());
+ consumer, aliasDb_.value());
if (!shouldFuse) {
return at::nullopt;
}
auto group = consumer;
- if (consumer->kind() == aten::cat) {
- Graph* graph = consumer->owningGraph();
- Node* list_construct = consumer->namedInput(attr::tensors)->node();
- int64_t dim = consumer->get<int64_t>(attr::dim).value();
-
- Node* fused_cat =
- graph->create(prim::FusedConcat, list_construct->inputs())
- ->i_(attr::dim, dim);
- fused_cat->insertBefore(list_construct);
- fused_cat->output()->copyMetadata(consumer->output());
- consumer->output()->replaceAllUsesWith(fused_cat->output());
-
- // NB: this deletes the fused_cat node from the original graph
- group = createSingletonFusionGroup(fused_cat);
- consumer->destroy();
- if (list_construct->output()->uses().empty()) {
- list_construct->destroy();
- }
- } else if (consumer->kind() != prim::FusionGroup) {
+ if (consumer->kind() != prim::FusionGroup) {
group = createSingletonFusionGroup(consumer);
}
if (producer->node()->matches(
// returns where to continue scanning, and whether any fusion was made
std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
- if (isFusableAsExitNode(consumer)) {
- auto consumer_inputs = consumer->kind() == aten::cat
- ? consumer->namedInput(attr::tensors)->node()->inputs()
- : consumer->inputs();
+ if (isFusable(consumer)) {
// handle inputs in reverse topological order as well...
// otherwise in f(a,a+b) it will appear a is used twice if we consider
// the f-a fusion before the f-(a+b) fusion first.
- auto inputs = sortReverseTopological(consumer_inputs);
+ auto inputs = sortReverseTopological(consumer->inputs());
for (auto producer : inputs) {
- // Don't fuse if producer must come from a FusionGroup exit node
- if (mustRemainAsFusionGroupOutput(producer))
- continue;
if (tryToMoveChunk(consumer, producer)) {
// the chunk before this consumer was re-arranged to allow fusion,
// we scan this consumer again to perform the fusion
aliasDb_ = AliasAnalysis(graph_);
}
+ bool canFuseWithConcat(Value* producer, Node* before_check) {
+ if (!isFusable(producer->node())) {
+ return false;
+ }
+ // NB: it is important that this check happens after isFusable, which checks
+ // that the blocks match, and it's not a special node like prim::Param
+ if (!producer->node()->couldMoveBeforeTopologically(
+ before_check, aliasDb_.value())) {
+ return false;
+ }
+ // Fusion groups can be merged with concat's group if and only if
+ // the value they produce isn't already coming from a concat.
+ if (producer->node()->kind() == prim::FusionGroup) {
+ auto subgraph = producer->node()->g(attr::Subgraph);
+ auto* node = subgraph->outputs().at(producer->offset())->node();
+ return node->kind() != prim::FusedConcat;
+ }
+ return true;
+ }
+
+ Node* createFusedConcat(Node* node) {
+ JIT_ASSERT(node->kind() == aten::cat);
+
+ Graph* graph = node->owningGraph();
+ Node* list_construct = node->namedInput(attr::tensors)->node();
+ int64_t dim = node->get<int64_t>(attr::dim).value();
+
+ Node* fused_cat = graph->create(prim::FusedConcat, list_construct->inputs())
+ ->i_(attr::dim, dim);
+ fused_cat->insertBefore(list_construct);
+ fused_cat->output()->copyMetadata(node->output());
+
+ // NB: this deletes the fused_cat node from the original graph
+ return createSingletonFusionGroup(fused_cat);
+ }
+
+ void fuseConcats() {
+ for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();
+ ++it) {
+ Node* cat = *it;
+ if (!isFusableCatNode(cat)) {
+ continue;
+ }
+ Node* list_construct = cat->namedInput(attr::tensors)->node();
+ Node* fused_cat = createFusedConcat(cat);
+ Value* fused_cat_out = fused_cat->output();
+
+ auto sorted_inputs = sortReverseTopological(fused_cat->inputs());
+ size_t input_idx = 0;
+ bool any_fused = false;
+ while (input_idx < sorted_inputs.size()) {
+ Value* input = sorted_inputs[input_idx++];
+ if (!canFuseWithConcat(input, fused_cat)) {
+ continue;
+ }
+ any_fused = true;
+ auto maybe_group = tryFuse(fused_cat, input);
+ JIT_ASSERT(maybe_group && maybe_group == fused_cat);
+ // We could have destroyed multiple inputs when performing this fusion,
+ // so we have to recompute the list and iterate over it again.
+ sorted_inputs = sortReverseTopological(fused_cat->inputs());
+ input_idx = 0;
+ }
+
+ if (any_fused) {
+ cat->output()->replaceAllUsesWith(fused_cat_out);
+ it.destroyCurrent();
+ if (list_construct->output()->uses().empty()) {
+ list_construct->destroy();
+ }
+ } else {
+ fused_cat->destroy();
+ }
+ }
+ }
+
+ void optimizeFusedGraphs() {
+ for (Node* node : block_->nodes()) {
+ if (node->kind() != prim::FusionGroup) {
+ continue;
+ }
+ auto subgraph = node->g(attr::Subgraph);
+ EliminateDeadCode(subgraph);
+ EliminateCommonSubexpression(subgraph);
+ ConstantPooling(subgraph);
+ }
+ }
+
void run() {
// Run the pass until no changes are made.
// This is neccessary, because the algorithm can miss out on certain fusion
any_changed |= changed;
}
}
+ refreshAliasDb();
+
+ fuseConcats();
+
+ optimizeFusedGraphs();
// The graph fuser can add intermediate prim::BroadcastingChunk nodes.
// Replace them with broadcasts + chunks.