Simplify cat fusion (#15633)
authorAdam Paszke <adam.paszke@gmail.com>
Fri, 11 Jan 2019 18:17:54 +0000 (10:17 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 11 Jan 2019 18:33:42 +0000 (10:33 -0800)
Summary:
That makes that definition of a "fusable node" much simpler,
as we don't need to keep considering whether something has to be an
"exit node" at every step. The fuser now tries to maximize the
pointwise fusions first, and proceeds to prepending chunks and appending
concats only once a fix point is reached.

This patch not only makes the fuser much simpler to reason about,
making it siginifcantly easier to implement features like SumToSize
fusion, to improve performance of derivative graphs.

cc zou3519 mruberry
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15633

Differential Revision: D13575306

Pulled By: zou3519

fbshipit-source-id: 0c55ea61d65d1f1ed3d75a8e1e83bc85a83f3aff

13 files changed:
test/expect/TestFuser.test_broadcast_cuda.expect
test/expect/TestFuser.test_chunk_cuda.expect
test/expect/TestFuser.test_chunk_distributes_cuda.expect
test/expect/TestFuser.test_chunk_multiple_cuda.expect
test/expect/TestFuser.test_concat_invariant_cuda.expect
test/expect/TestFuser.test_lstm_concat_cuda.expect
test/expect/TestFuser.test_lstm_cuda-backward.expect
test/expect/TestFuser.test_lstm_cuda-forward.expect
test/expect/TestFuser.test_lstm_traced_cuda.expect
test/expect/TestFuser.test_milstm_cuda-backward.expect
test/expect/TestFuser.test_milstm_cuda-forward.expect
test/expect/TestFuser.test_scalar.expect
torch/csrc/jit/passes/graph_fuser.cpp

index c3ec90c..47a1c9e 100644 (file)
@@ -7,8 +7,8 @@ graph(%x : Float(*, *)
 with prim::FusionGroup_0 = graph(%0 : Float(*)
       %1 : Float(*, *)
       %2 : Float(*)) {
-  %3 : Float(*, *) = aten::mul(%1, %2)
-  %4 : int = prim::Constant[value=1]()
-  %5 : Float(*, *) = aten::add(%3, %0, %4)
+  %3 : int = prim::Constant[value=1]()
+  %4 : Float(*, *) = aten::mul(%1, %2)
+  %5 : Float(*, *) = aten::add(%4, %0, %3)
   return (%5);
 }
index c22a28b..c2eb07e 100644 (file)
@@ -4,8 +4,8 @@ graph(%x : Float(*, *)) {
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)) {
   %1 : Float(*, *), %2 : Float(*, *), %3 : Float(*, *) = prim::ConstantChunk[chunks=3, dim=1](%0)
-  %4 : Float(*, *) = aten::mul(%1, %2)
-  %5 : int = prim::Constant[value=1]()
-  %6 : Float(*, *) = aten::add(%4, %3, %5)
+  %4 : int = prim::Constant[value=1]()
+  %5 : Float(*, *) = aten::mul(%1, %2)
+  %6 : Float(*, *) = aten::add(%5, %3, %4)
   return (%6);
 }
index b6ecb56..55abd69 100644 (file)
@@ -12,8 +12,7 @@ with prim::FusionGroup_0 = graph(%0 : Tensor
   %4 : Float(*, *), %5 : Float(*, *) = prim::ConstantChunk[chunks=2, dim=1](%0)
   %6 : int = prim::Constant[value=1]()
   %7 : Float(*, *) = aten::add(%2, %4, %6)
-  %8 : int = prim::Constant[value=1]()
-  %9 : Float(*, *) = aten::add(%3, %5, %8)
-  %10 : Float(*, *) = aten::mul(%7, %9)
-  return (%10);
+  %8 : Float(*, *) = aten::add(%3, %5, %6)
+  %9 : Float(*, *) = aten::mul(%7, %8)
+  return (%9);
 }
index 6930d95..5f6283b 100644 (file)
@@ -14,17 +14,11 @@ with prim::FusionGroup_0 = graph(%0 : Float(*, *, *)
   %9 : Float(*, *, *), %10 : Float(*, *, *) = prim::ConstantChunk[chunks=2, dim=0](%1)
   %11 : int = prim::Constant[value=1]()
   %12 : Float(*, *, *) = aten::add(%0, %6, %11)
-  %13 : int = prim::Constant[value=1]()
-  %14 : Float(*, *, *) = aten::add(%12, %7, %13)
-  %15 : int = prim::Constant[value=1]()
-  %16 : Float(*, *, *) = aten::add(%14, %8, %15)
-  %17 : int = prim::Constant[value=1]()
-  %18 : Float(*, *, *) = aten::add(%16, %9, %17)
-  %19 : int = prim::Constant[value=1]()
-  %20 : Float(*, *, *) = aten::add(%18, %10, %19)
-  %21 : int = prim::Constant[value=1]()
-  %22 : Float(*, *, *) = aten::add(%20, %4, %21)
-  %23 : int = prim::Constant[value=1]()
-  %24 : Float(*, *, *) = aten::add(%22, %5, %23)
-  return (%24);
+  %13 : Float(*, *, *) = aten::add(%12, %7, %11)
+  %14 : Float(*, *, *) = aten::add(%13, %8, %11)
+  %15 : Float(*, *, *) = aten::add(%14, %9, %11)
+  %16 : Float(*, *, *) = aten::add(%15, %10, %11)
+  %17 : Float(*, *, *) = aten::add(%16, %4, %11)
+  %18 : Float(*, *, *) = aten::add(%17, %5, %11)
+  return (%18);
 }
index a2e2eb8..6828d16 100644 (file)
@@ -10,8 +10,7 @@ with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Float(*, *)) {
   %2 : int = prim::Constant[value=1]()
   %x1 : Float(*, *) = aten::add(%0, %1, %2)
-  %4 : int = prim::Constant[value=1]()
-  %y1 : Float(*, *) = aten::sub(%0, %1, %4)
+  %y1 : Float(*, *) = aten::sub(%0, %1, %2)
   %w : Float(*, *) = prim::FusedConcat[dim=0](%x1, %y1)
   return (%w);
 }
index 8a572f0..772df50 100644 (file)
@@ -26,38 +26,26 @@ with prim::FusionGroup_0 = graph(%0 : Float(*, *)
   %17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
   %21 : int = prim::Constant[value=1]()
   %22 : Float(*, *) = aten::add(%13, %17, %21)
-  %23 : int = prim::Constant[value=1]()
-  %24 : Float(*, *) = aten::add(%14, %18, %23)
-  %25 : int = prim::Constant[value=1]()
-  %26 : Float(*, *) = aten::add(%15, %19, %25)
-  %27 : int = prim::Constant[value=1]()
-  %28 : Float(*, *) = aten::add(%16, %20, %27)
-  %29 : int = prim::Constant[value=1]()
-  %30 : Float(*, *) = aten::add(%5, %9, %29)
-  %31 : int = prim::Constant[value=1]()
-  %32 : Float(*, *) = aten::add(%6, %10, %31)
-  %33 : int = prim::Constant[value=1]()
-  %34 : Float(*, *) = aten::add(%7, %11, %33)
-  %35 : int = prim::Constant[value=1]()
-  %36 : Float(*, *) = aten::add(%8, %12, %35)
-  %37 : int = prim::Constant[value=1]()
-  %38 : Float(*, *) = aten::add(%30, %22, %37)
-  %39 : int = prim::Constant[value=1]()
-  %40 : Float(*, *) = aten::add(%32, %24, %39)
-  %41 : int = prim::Constant[value=1]()
-  %42 : Float(*, *) = aten::add(%34, %26, %41)
-  %43 : int = prim::Constant[value=1]()
-  %44 : Float(*, *) = aten::add(%36, %28, %43)
-  %ingate : Float(*, *) = aten::sigmoid(%38)
-  %forgetgate : Float(*, *) = aten::sigmoid(%40)
-  %cellgate : Float(*, *) = aten::tanh(%42)
-  %outgate : Float(*, *) = aten::sigmoid(%44)
-  %49 : Float(*, *) = aten::mul(%forgetgate, %0)
-  %50 : Float(*, *) = aten::mul(%ingate, %cellgate)
-  %51 : int = prim::Constant[value=1]()
-  %cy : Float(*, *) = aten::add(%49, %50, %51)
-  %53 : Float(*, *) = aten::tanh(%cy)
-  %hy : Float(*, *) = aten::mul(%outgate, %53)
-  %55 : Float(*, *) = prim::FusedConcat[dim=0](%hy, %cy)
-  return (%55);
+  %23 : Float(*, *) = aten::add(%14, %18, %21)
+  %24 : Float(*, *) = aten::add(%15, %19, %21)
+  %25 : Float(*, *) = aten::add(%16, %20, %21)
+  %26 : Float(*, *) = aten::add(%5, %9, %21)
+  %27 : Float(*, *) = aten::add(%6, %10, %21)
+  %28 : Float(*, *) = aten::add(%7, %11, %21)
+  %29 : Float(*, *) = aten::add(%8, %12, %21)
+  %30 : Float(*, *) = aten::add(%26, %22, %21)
+  %31 : Float(*, *) = aten::add(%27, %23, %21)
+  %32 : Float(*, *) = aten::add(%28, %24, %21)
+  %33 : Float(*, *) = aten::add(%29, %25, %21)
+  %ingate : Float(*, *) = aten::sigmoid(%30)
+  %forgetgate : Float(*, *) = aten::sigmoid(%31)
+  %cellgate : Float(*, *) = aten::tanh(%32)
+  %outgate : Float(*, *) = aten::sigmoid(%33)
+  %38 : Float(*, *) = aten::mul(%forgetgate, %0)
+  %39 : Float(*, *) = aten::mul(%ingate, %cellgate)
+  %cy : Float(*, *) = aten::add(%38, %39, %21)
+  %41 : Float(*, *) = aten::tanh(%cy)
+  %hy : Float(*, *) = aten::mul(%outgate, %41)
+  %43 : Float(*, *) = prim::FusedConcat[dim=0](%hy, %cy)
+  return (%43);
 }
index de28768..1159f70 100644 (file)
@@ -66,14 +66,13 @@ graph(%0 : Float(*, *)
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Tensor
       %2 : Float(*, *)) {
-  %3 : Float(*, *) = aten::mul(%2, %2)
-  %4 : Float(*, *) = aten::neg(%3)
-  %5 : int = prim::Constant[value=1]()
-  %6 : Float(*, *) = aten::add(%4, %5, %5)
+  %3 : int = prim::Constant[value=1]()
+  %4 : Float(*, *) = aten::mul(%2, %2)
+  %5 : Float(*, *) = aten::neg(%4)
+  %6 : Float(*, *) = aten::add(%5, %3, %3)
   %7 : Tensor = aten::mul(%1, %6)
-  %8 : int = prim::Constant[value=1]()
-  %9 : Tensor = aten::add(%0, %7, %8)
-  return (%9);
+  %8 : Tensor = aten::add(%0, %7, %3)
+  return (%8);
 }
 with prim::FusionGroup_1 = graph(%0 : Tensor
       %1 : Float(*, *)
@@ -83,26 +82,23 @@ with prim::FusionGroup_1 = graph(%0 : Tensor
       %5 : Float(*, *)
       %6 : Tensor
       %7 : Float(*, *)) {
-  %8 : Float(*, *) = aten::neg(%7)
-  %9 : int = prim::Constant[value=1]()
-  %10 : Float(*, *) = aten::add(%8, %9, %9)
+  %8 : int = prim::Constant[value=1]()
+  %9 : Float(*, *) = aten::neg(%7)
+  %10 : Float(*, *) = aten::add(%9, %8, %8)
   %11 : Float(*, *) = aten::mul(%10, %7)
   %12 : Tensor = aten::mul(%11, %6)
   %13 : Float(*, *) = aten::mul(%5, %5)
   %14 : Float(*, *) = aten::neg(%13)
-  %15 : int = prim::Constant[value=1]()
-  %16 : Float(*, *) = aten::add(%14, %15, %15)
-  %17 : Tensor = aten::mul(%4, %16)
-  %18 : Float(*, *) = aten::neg(%3)
-  %19 : int = prim::Constant[value=1]()
-  %20 : Float(*, *) = aten::add(%18, %19, %19)
-  %21 : Float(*, *) = aten::mul(%20, %3)
-  %22 : Tensor = aten::mul(%21, %2)
-  %23 : Float(*, *) = aten::neg(%1)
-  %24 : int = prim::Constant[value=1]()
-  %25 : Float(*, *) = aten::add(%23, %24, %24)
-  %26 : Float(*, *) = aten::mul(%25, %1)
-  %27 : Tensor = aten::mul(%26, %0)
-  %28 : Tensor = prim::FusedConcat[dim=1](%27, %22, %17, %12)
-  return (%28);
+  %15 : Float(*, *) = aten::add(%14, %8, %8)
+  %16 : Tensor = aten::mul(%4, %15)
+  %17 : Float(*, *) = aten::neg(%3)
+  %18 : Float(*, *) = aten::add(%17, %8, %8)
+  %19 : Float(*, *) = aten::mul(%18, %3)
+  %20 : Tensor = aten::mul(%19, %2)
+  %21 : Float(*, *) = aten::neg(%1)
+  %22 : Float(*, *) = aten::add(%21, %8, %8)
+  %23 : Float(*, *) = aten::mul(%22, %1)
+  %24 : Tensor = aten::mul(%23, %0)
+  %25 : Tensor = prim::FusedConcat[dim=1](%24, %20, %16, %12)
+  return (%25);
 }
index 55fbd37..4cbf825 100644 (file)
@@ -49,37 +49,25 @@ with prim::FusionGroup_0 = graph(%0 : Float(*, *)
   %17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
   %21 : int = prim::Constant[value=1]()
   %22 : Float(*, *) = aten::add(%13, %17, %21)
-  %23 : int = prim::Constant[value=1]()
-  %24 : Float(*, *) = aten::add(%14, %18, %23)
-  %25 : int = prim::Constant[value=1]()
-  %26 : Float(*, *) = aten::add(%15, %19, %25)
-  %27 : int = prim::Constant[value=1]()
-  %28 : Float(*, *) = aten::add(%16, %20, %27)
-  %29 : int = prim::Constant[value=1]()
-  %30 : Float(*, *) = aten::add(%22, %9, %29)
-  %31 : int = prim::Constant[value=1]()
-  %32 : Float(*, *) = aten::add(%24, %10, %31)
-  %33 : int = prim::Constant[value=1]()
-  %34 : Float(*, *) = aten::add(%26, %11, %33)
-  %35 : int = prim::Constant[value=1]()
-  %36 : Float(*, *) = aten::add(%28, %12, %35)
-  %37 : int = prim::Constant[value=1]()
-  %38 : Float(*, *) = aten::add(%30, %5, %37)
-  %39 : int = prim::Constant[value=1]()
-  %40 : Float(*, *) = aten::add(%32, %6, %39)
-  %41 : int = prim::Constant[value=1]()
-  %42 : Float(*, *) = aten::add(%34, %7, %41)
-  %43 : int = prim::Constant[value=1]()
-  %44 : Float(*, *) = aten::add(%36, %8, %43)
-  %ingate.1 : Float(*, *) = aten::sigmoid(%38)
-  %forgetgate.1 : Float(*, *) = aten::sigmoid(%40)
-  %cellgate.1 : Float(*, *) = aten::tanh(%42)
-  %outgate.1 : Float(*, *) = aten::sigmoid(%44)
-  %49 : Float(*, *) = aten::mul(%forgetgate.1, %0)
-  %50 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
-  %51 : int = prim::Constant[value=1]()
-  %cy : Float(*, *) = aten::add(%49, %50, %51)
-  %53 : Float(*, *) = aten::tanh(%cy)
-  %hy : Float(*, *) = aten::mul(%outgate.1, %53)
-  return (%hy, %53, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
+  %23 : Float(*, *) = aten::add(%14, %18, %21)
+  %24 : Float(*, *) = aten::add(%15, %19, %21)
+  %25 : Float(*, *) = aten::add(%16, %20, %21)
+  %26 : Float(*, *) = aten::add(%22, %9, %21)
+  %27 : Float(*, *) = aten::add(%23, %10, %21)
+  %28 : Float(*, *) = aten::add(%24, %11, %21)
+  %29 : Float(*, *) = aten::add(%25, %12, %21)
+  %30 : Float(*, *) = aten::add(%26, %5, %21)
+  %31 : Float(*, *) = aten::add(%27, %6, %21)
+  %32 : Float(*, *) = aten::add(%28, %7, %21)
+  %33 : Float(*, *) = aten::add(%29, %8, %21)
+  %ingate.1 : Float(*, *) = aten::sigmoid(%30)
+  %forgetgate.1 : Float(*, *) = aten::sigmoid(%31)
+  %cellgate.1 : Float(*, *) = aten::tanh(%32)
+  %outgate.1 : Float(*, *) = aten::sigmoid(%33)
+  %38 : Float(*, *) = aten::mul(%forgetgate.1, %0)
+  %39 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
+  %cy : Float(*, *) = aten::add(%38, %39, %21)
+  %41 : Float(*, *) = aten::tanh(%cy)
+  %hy : Float(*, *) = aten::mul(%outgate.1, %41)
+  return (%hy, %41, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
 }
index 3a02e32..651d2df 100644 (file)
@@ -27,37 +27,25 @@ with prim::FusionGroup_0 = graph(%0 : Float(*, *)
   %17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
   %21 : int = prim::Constant[value=1]()
   %22 : Float(*, *) = aten::add(%13, %17, %21)
-  %23 : int = prim::Constant[value=1]()
-  %24 : Float(*, *) = aten::add(%14, %18, %23)
-  %25 : int = prim::Constant[value=1]()
-  %26 : Float(*, *) = aten::add(%15, %19, %25)
-  %27 : int = prim::Constant[value=1]()
-  %28 : Float(*, *) = aten::add(%16, %20, %27)
-  %29 : int = prim::Constant[value=1]()
-  %30 : Float(*, *) = aten::add(%5, %9, %29)
-  %31 : int = prim::Constant[value=1]()
-  %32 : Float(*, *) = aten::add(%6, %10, %31)
-  %33 : int = prim::Constant[value=1]()
-  %34 : Float(*, *) = aten::add(%7, %11, %33)
-  %35 : int = prim::Constant[value=1]()
-  %36 : Float(*, *) = aten::add(%8, %12, %35)
-  %37 : int = prim::Constant[value=1]()
-  %38 : Float(*, *) = aten::add(%30, %22, %37)
-  %39 : int = prim::Constant[value=1]()
-  %40 : Float(*, *) = aten::add(%32, %24, %39)
-  %41 : int = prim::Constant[value=1]()
-  %42 : Float(*, *) = aten::add(%34, %26, %41)
-  %43 : int = prim::Constant[value=1]()
-  %44 : Float(*, *) = aten::add(%36, %28, %43)
-  %ingate : Float(*, *) = aten::sigmoid(%38)
-  %forgetgate : Float(*, *) = aten::sigmoid(%40)
-  %cellgate : Float(*, *) = aten::tanh(%42)
-  %outgate : Float(*, *) = aten::sigmoid(%44)
-  %49 : Float(*, *) = aten::mul(%forgetgate, %0)
-  %50 : Float(*, *) = aten::mul(%ingate, %cellgate)
-  %51 : int = prim::Constant[value=1]()
-  %cy : Float(*, *) = aten::add(%49, %50, %51)
-  %53 : Float(*, *) = aten::tanh(%cy)
-  %54 : Float(*, *) = aten::mul(%outgate, %53)
-  return (%54, %cy);
+  %23 : Float(*, *) = aten::add(%14, %18, %21)
+  %24 : Float(*, *) = aten::add(%15, %19, %21)
+  %25 : Float(*, *) = aten::add(%16, %20, %21)
+  %26 : Float(*, *) = aten::add(%5, %9, %21)
+  %27 : Float(*, *) = aten::add(%6, %10, %21)
+  %28 : Float(*, *) = aten::add(%7, %11, %21)
+  %29 : Float(*, *) = aten::add(%8, %12, %21)
+  %30 : Float(*, *) = aten::add(%26, %22, %21)
+  %31 : Float(*, *) = aten::add(%27, %23, %21)
+  %32 : Float(*, *) = aten::add(%28, %24, %21)
+  %33 : Float(*, *) = aten::add(%29, %25, %21)
+  %ingate : Float(*, *) = aten::sigmoid(%30)
+  %forgetgate : Float(*, *) = aten::sigmoid(%31)
+  %cellgate : Float(*, *) = aten::tanh(%32)
+  %outgate : Float(*, *) = aten::sigmoid(%33)
+  %38 : Float(*, *) = aten::mul(%forgetgate, %0)
+  %39 : Float(*, *) = aten::mul(%ingate, %cellgate)
+  %cy : Float(*, *) = aten::add(%38, %39, %21)
+  %41 : Float(*, *) = aten::tanh(%cy)
+  %42 : Float(*, *) = aten::mul(%outgate, %41)
+  return (%42, %cy);
 }
index 7f45d2c..e664021 100644 (file)
@@ -89,14 +89,13 @@ graph(%0 : Float(*, *)
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Tensor
       %2 : Float(*, *)) {
-  %3 : Float(*, *) = aten::mul(%2, %2)
-  %4 : Float(*, *) = aten::neg(%3)
-  %5 : int = prim::Constant[value=1]()
-  %6 : Float(*, *) = aten::add(%4, %5, %5)
+  %3 : int = prim::Constant[value=1]()
+  %4 : Float(*, *) = aten::mul(%2, %2)
+  %5 : Float(*, *) = aten::neg(%4)
+  %6 : Float(*, *) = aten::add(%5, %3, %3)
   %7 : Tensor = aten::mul(%1, %6)
-  %8 : int = prim::Constant[value=1]()
-  %9 : Tensor = aten::add(%0, %7, %8)
-  return (%9);
+  %8 : Tensor = aten::add(%0, %7, %3)
+  return (%8);
 }
 with prim::FusionGroup_1 = graph(%0 : Tensor
       %1 : Float(*, *)
@@ -106,26 +105,23 @@ with prim::FusionGroup_1 = graph(%0 : Tensor
       %5 : Float(*, *)
       %6 : Tensor
       %7 : Float(*, *)) {
-  %8 : Float(*, *) = aten::neg(%7)
-  %9 : int = prim::Constant[value=1]()
-  %10 : Float(*, *) = aten::add(%8, %9, %9)
+  %8 : int = prim::Constant[value=1]()
+  %9 : Float(*, *) = aten::neg(%7)
+  %10 : Float(*, *) = aten::add(%9, %8, %8)
   %11 : Float(*, *) = aten::mul(%10, %7)
   %12 : Tensor = aten::mul(%11, %6)
   %13 : Float(*, *) = aten::mul(%5, %5)
   %14 : Float(*, *) = aten::neg(%13)
-  %15 : int = prim::Constant[value=1]()
-  %16 : Float(*, *) = aten::add(%14, %15, %15)
-  %17 : Tensor = aten::mul(%4, %16)
-  %18 : Float(*, *) = aten::neg(%3)
-  %19 : int = prim::Constant[value=1]()
-  %20 : Float(*, *) = aten::add(%18, %19, %19)
-  %21 : Float(*, *) = aten::mul(%20, %3)
-  %22 : Tensor = aten::mul(%21, %2)
-  %23 : Float(*, *) = aten::neg(%1)
-  %24 : int = prim::Constant[value=1]()
-  %25 : Float(*, *) = aten::add(%23, %24, %24)
-  %26 : Float(*, *) = aten::mul(%25, %1)
-  %27 : Tensor = aten::mul(%26, %0)
-  %28 : Tensor = prim::FusedConcat[dim=1](%27, %22, %17, %12)
-  return (%28);
+  %15 : Float(*, *) = aten::add(%14, %8, %8)
+  %16 : Tensor = aten::mul(%4, %15)
+  %17 : Float(*, *) = aten::neg(%3)
+  %18 : Float(*, *) = aten::add(%17, %8, %8)
+  %19 : Float(*, *) = aten::mul(%18, %3)
+  %20 : Tensor = aten::mul(%19, %2)
+  %21 : Float(*, *) = aten::neg(%1)
+  %22 : Float(*, *) = aten::add(%21, %8, %8)
+  %23 : Float(*, *) = aten::mul(%22, %1)
+  %24 : Tensor = aten::mul(%23, %0)
+  %25 : Tensor = prim::FusedConcat[dim=1](%24, %20, %16, %12)
+  return (%25);
 }
index dd68b64..b900b51 100644 (file)
@@ -61,51 +61,39 @@ with prim::FusionGroup_0 = graph(%0 : Float(*, *)
   %19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
   %23 : Float(*, *), %24 : Float(*, *), %25 : Float(*, *), %26 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
   %27 : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %30 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
-  %31 : Float(*, *) = aten::mul(%23, %27)
-  %32 : Float(*, *) = aten::mul(%24, %28)
-  %33 : Float(*, *) = aten::mul(%25, %29)
-  %34 : Float(*, *) = aten::mul(%26, %30)
-  %35 : Float(*, *) = aten::mul(%19, %15)
-  %36 : Float(*, *) = aten::mul(%20, %16)
-  %37 : Float(*, *) = aten::mul(%21, %17)
-  %38 : Float(*, *) = aten::mul(%22, %18)
-  %39 : Float(*, *) = aten::mul(%11, %15)
-  %40 : Float(*, *) = aten::mul(%12, %16)
-  %41 : Float(*, *) = aten::mul(%13, %17)
-  %42 : Float(*, *) = aten::mul(%14, %18)
-  %43 : int = prim::Constant[value=1]()
-  %44 : Float(*, *) = aten::add(%35, %31, %43)
-  %45 : int = prim::Constant[value=1]()
-  %46 : Float(*, *) = aten::add(%36, %32, %45)
-  %47 : int = prim::Constant[value=1]()
-  %48 : Float(*, *) = aten::add(%37, %33, %47)
-  %49 : int = prim::Constant[value=1]()
-  %50 : Float(*, *) = aten::add(%38, %34, %49)
-  %51 : int = prim::Constant[value=1]()
-  %52 : Float(*, *) = aten::add(%44, %39, %51)
-  %53 : int = prim::Constant[value=1]()
-  %54 : Float(*, *) = aten::add(%46, %40, %53)
-  %55 : int = prim::Constant[value=1]()
-  %56 : Float(*, *) = aten::add(%48, %41, %55)
-  %57 : int = prim::Constant[value=1]()
-  %58 : Float(*, *) = aten::add(%50, %42, %57)
-  %59 : int = prim::Constant[value=1]()
-  %60 : Float(*, *) = aten::add(%52, %7, %59)
-  %61 : int = prim::Constant[value=1]()
-  %62 : Float(*, *) = aten::add(%54, %8, %61)
-  %63 : int = prim::Constant[value=1]()
-  %64 : Float(*, *) = aten::add(%56, %9, %63)
-  %65 : int = prim::Constant[value=1]()
-  %66 : Float(*, *) = aten::add(%58, %10, %65)
-  %ingate.1 : Float(*, *) = aten::sigmoid(%60)
-  %forgetgate.1 : Float(*, *) = aten::sigmoid(%62)
-  %cellgate.1 : Float(*, *) = aten::tanh(%64)
-  %outgate.1 : Float(*, *) = aten::sigmoid(%66)
-  %71 : Float(*, *) = aten::mul(%forgetgate.1, %0)
-  %72 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
-  %73 : int = prim::Constant[value=1]()
-  %cy : Float(*, *) = aten::add(%71, %72, %73)
-  %75 : Float(*, *) = aten::tanh(%cy)
-  %hy : Float(*, *) = aten::mul(%outgate.1, %75)
-  return (%hy, %75, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
+  %31 : int = prim::Constant[value=1]()
+  %32 : Float(*, *) = aten::mul(%23, %27)
+  %33 : Float(*, *) = aten::mul(%24, %28)
+  %34 : Float(*, *) = aten::mul(%25, %29)
+  %35 : Float(*, *) = aten::mul(%26, %30)
+  %36 : Float(*, *) = aten::mul(%19, %15)
+  %37 : Float(*, *) = aten::mul(%20, %16)
+  %38 : Float(*, *) = aten::mul(%21, %17)
+  %39 : Float(*, *) = aten::mul(%22, %18)
+  %40 : Float(*, *) = aten::mul(%11, %15)
+  %41 : Float(*, *) = aten::mul(%12, %16)
+  %42 : Float(*, *) = aten::mul(%13, %17)
+  %43 : Float(*, *) = aten::mul(%14, %18)
+  %44 : Float(*, *) = aten::add(%36, %32, %31)
+  %45 : Float(*, *) = aten::add(%37, %33, %31)
+  %46 : Float(*, *) = aten::add(%38, %34, %31)
+  %47 : Float(*, *) = aten::add(%39, %35, %31)
+  %48 : Float(*, *) = aten::add(%44, %40, %31)
+  %49 : Float(*, *) = aten::add(%45, %41, %31)
+  %50 : Float(*, *) = aten::add(%46, %42, %31)
+  %51 : Float(*, *) = aten::add(%47, %43, %31)
+  %52 : Float(*, *) = aten::add(%48, %7, %31)
+  %53 : Float(*, *) = aten::add(%49, %8, %31)
+  %54 : Float(*, *) = aten::add(%50, %9, %31)
+  %55 : Float(*, *) = aten::add(%51, %10, %31)
+  %ingate.1 : Float(*, *) = aten::sigmoid(%52)
+  %forgetgate.1 : Float(*, *) = aten::sigmoid(%53)
+  %cellgate.1 : Float(*, *) = aten::tanh(%54)
+  %outgate.1 : Float(*, *) = aten::sigmoid(%55)
+  %60 : Float(*, *) = aten::mul(%forgetgate.1, %0)
+  %61 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
+  %cy : Float(*, *) = aten::add(%60, %61, %31)
+  %63 : Float(*, *) = aten::tanh(%cy)
+  %hy : Float(*, *) = aten::mul(%outgate.1, %63)
+  return (%hy, %63, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
 }
index f5849f4..ea2587e 100644 (file)
@@ -5,9 +5,9 @@ graph(%x : Float()
 }
 with prim::FusionGroup_0 = graph(%0 : Float()
       %1 : Float()) {
-  %2 : int = prim::Constant[value=2]()
-  %3 : Float() = aten::mul(%1, %2)
-  %4 : int = prim::Constant[value=1]()
-  %5 : Float() = aten::add(%3, %0, %4)
+  %2 : int = prim::Constant[value=1]()
+  %3 : int = prim::Constant[value=2]()
+  %4 : Float() = aten::mul(%1, %3)
+  %5 : Float() = aten::add(%4, %0, %2)
   return (%5);
 }
index cb69295..39ae5bc 100644 (file)
@@ -8,6 +8,7 @@
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/alias_analysis.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
 #include <torch/csrc/jit/script/compiler.h>
@@ -211,18 +212,6 @@ struct GraphFuser {
     return true;
   }
 
-  // Can this node produce an _output_ of a fusion group?
-  // all Fusable nodes can do this, but additionally Concat, which normally
-  // cannot be fused because it is not a simple map, can be put in a fusion
-  // group as long as no items in the group read the output of concat
-  bool isFusableAsExitNode(Node* node) {
-    return isFusableMap(node) || isFusableOnlyAsExitNode(node);
-  }
-
-  bool isFusableOnlyAsExitNode(Node* node) {
-    return isFusableCatNode(node) || node->kind() == prim::FusedConcat;
-  }
-
   bool calculatesSize(Node* node) {
     return node->matches("aten::size(Tensor self) -> int[]");
   }
@@ -238,15 +227,6 @@ struct GraphFuser {
     return true;
   }
 
-  bool mustRemainAsFusionGroupOutput(Value* producer) {
-    if (producer->node()->kind() != prim::FusionGroup) {
-      return false;
-    }
-    auto subgraph = producer->node()->g(attr::Subgraph);
-    auto* node = subgraph->outputs().at(producer->offset())->node();
-    return isFusableOnlyAsExitNode(node);
-  }
-
   Graph& getSubgraph(Node* n) {
     JIT_ASSERT(n->kind() == prim::FusionGroup);
     return *n->g(attr::Subgraph);
@@ -457,41 +437,20 @@ struct GraphFuser {
     // we can move the consumer up into the producer.
     // but this requires better handling of merging fusion groups so it is not
     // done now
-    Node* real_consumer = consumer->kind() == aten::cat
-        ? consumer->namedInput(attr::tensors)->node()
-        : consumer;
     bool shouldFuse = isFusable(producer->node()) &&
         // Rearrange nodes such that all uses of producer are after the
         // consumer. Fusion will rewrite those later uses to use the version of
         // producer generated by the fused blob. In this case, producer becomes
         // an output of the fusion group.
         producer->node()->moveBeforeTopologicallyValid(
-            real_consumer, aliasDb_.value());
+            consumer, aliasDb_.value());
 
     if (!shouldFuse) {
       return at::nullopt;
     }
 
     auto group = consumer;
-    if (consumer->kind() == aten::cat) {
-      Graph* graph = consumer->owningGraph();
-      Node* list_construct = consumer->namedInput(attr::tensors)->node();
-      int64_t dim = consumer->get<int64_t>(attr::dim).value();
-
-      Node* fused_cat =
-          graph->create(prim::FusedConcat, list_construct->inputs())
-              ->i_(attr::dim, dim);
-      fused_cat->insertBefore(list_construct);
-      fused_cat->output()->copyMetadata(consumer->output());
-      consumer->output()->replaceAllUsesWith(fused_cat->output());
-
-      // NB: this deletes the fused_cat node from the original graph
-      group = createSingletonFusionGroup(fused_cat);
-      consumer->destroy();
-      if (list_construct->output()->uses().empty()) {
-        list_construct->destroy();
-      }
-    } else if (consumer->kind() != prim::FusionGroup) {
+    if (consumer->kind() != prim::FusionGroup) {
       group = createSingletonFusionGroup(consumer);
     }
     if (producer->node()->matches(
@@ -884,18 +843,12 @@ struct GraphFuser {
 
   // returns where to continue scanning, and whether any fusion was made
   std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
-    if (isFusableAsExitNode(consumer)) {
-      auto consumer_inputs = consumer->kind() == aten::cat
-          ? consumer->namedInput(attr::tensors)->node()->inputs()
-          : consumer->inputs();
+    if (isFusable(consumer)) {
       // handle inputs in reverse topological order as well...
       // otherwise in f(a,a+b) it will appear a is used twice if we consider
       // the f-a fusion before the f-(a+b) fusion first.
-      auto inputs = sortReverseTopological(consumer_inputs);
+      auto inputs = sortReverseTopological(consumer->inputs());
       for (auto producer : inputs) {
-        // Don't fuse if producer must come from a FusionGroup exit node
-        if (mustRemainAsFusionGroupOutput(producer))
-          continue;
         if (tryToMoveChunk(consumer, producer)) {
           // the chunk before this consumer was re-arranged to allow fusion,
           // we scan this consumer again to perform the fusion
@@ -1056,6 +1009,94 @@ struct GraphFuser {
     aliasDb_ = AliasAnalysis(graph_);
   }
 
+  bool canFuseWithConcat(Value* producer, Node* before_check) {
+    if (!isFusable(producer->node())) {
+      return false;
+    }
+    // NB: it is important that this check happens after isFusable, which checks
+    // that the blocks match, and it's not a special node like prim::Param
+    if (!producer->node()->couldMoveBeforeTopologically(
+            before_check, aliasDb_.value())) {
+      return false;
+    }
+    // Fusion groups can be merged with concat's group if and only if
+    // the value they produce isn't already coming from a concat.
+    if (producer->node()->kind() == prim::FusionGroup) {
+      auto subgraph = producer->node()->g(attr::Subgraph);
+      auto* node = subgraph->outputs().at(producer->offset())->node();
+      return node->kind() != prim::FusedConcat;
+    }
+    return true;
+  }
+
+  Node* createFusedConcat(Node* node) {
+    JIT_ASSERT(node->kind() == aten::cat);
+
+    Graph* graph = node->owningGraph();
+    Node* list_construct = node->namedInput(attr::tensors)->node();
+    int64_t dim = node->get<int64_t>(attr::dim).value();
+
+    Node* fused_cat = graph->create(prim::FusedConcat, list_construct->inputs())
+                          ->i_(attr::dim, dim);
+    fused_cat->insertBefore(list_construct);
+    fused_cat->output()->copyMetadata(node->output());
+
+    // NB: this deletes the fused_cat node from the original graph
+    return createSingletonFusionGroup(fused_cat);
+  }
+
+  void fuseConcats() {
+    for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();
+         ++it) {
+      Node* cat = *it;
+      if (!isFusableCatNode(cat)) {
+        continue;
+      }
+      Node* list_construct = cat->namedInput(attr::tensors)->node();
+      Node* fused_cat = createFusedConcat(cat);
+      Value* fused_cat_out = fused_cat->output();
+
+      auto sorted_inputs = sortReverseTopological(fused_cat->inputs());
+      size_t input_idx = 0;
+      bool any_fused = false;
+      while (input_idx < sorted_inputs.size()) {
+        Value* input = sorted_inputs[input_idx++];
+        if (!canFuseWithConcat(input, fused_cat)) {
+          continue;
+        }
+        any_fused = true;
+        auto maybe_group = tryFuse(fused_cat, input);
+        JIT_ASSERT(maybe_group && maybe_group == fused_cat);
+        // We could have destroyed multiple inputs when performing this fusion,
+        // so we have to recompute the list and iterate over it again.
+        sorted_inputs = sortReverseTopological(fused_cat->inputs());
+        input_idx = 0;
+      }
+
+      if (any_fused) {
+        cat->output()->replaceAllUsesWith(fused_cat_out);
+        it.destroyCurrent();
+        if (list_construct->output()->uses().empty()) {
+          list_construct->destroy();
+        }
+      } else {
+        fused_cat->destroy();
+      }
+    }
+  }
+
+  void optimizeFusedGraphs() {
+    for (Node* node : block_->nodes()) {
+      if (node->kind() != prim::FusionGroup) {
+        continue;
+      }
+      auto subgraph = node->g(attr::Subgraph);
+      EliminateDeadCode(subgraph);
+      EliminateCommonSubexpression(subgraph);
+      ConstantPooling(subgraph);
+    }
+  }
+
   void run() {
     // Run the pass until no changes are made.
     // This is neccessary, because the algorithm can miss out on certain fusion
@@ -1083,6 +1124,11 @@ struct GraphFuser {
         any_changed |= changed;
       }
     }
+    refreshAliasDb();
+
+    fuseConcats();
+
+    optimizeFusedGraphs();
 
     // The graph fuser can add intermediate prim::BroadcastingChunk nodes.
     // Replace them with broadcasts + chunks.