From d76fd43294693be5a5a38251051345f2150d421c Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 4 Dec 2018 15:40:41 -0800 Subject: [PATCH] Reenable all forward-pass fusions that worked before the AD fix (#14558) Summary: Dealing with so many `aten::size` calls (in particular calls on elements computed inside fusion groups) requires us to do some extra graph processing in the fuser (to compute the sizes by explicit broadcasts, instead of writing the intermediate tensors only to check their size). This restores the forward expects of LSTM and MiLSTM to a single big kernel. Unfortunately the backward is much harder, because as long as we can't prove that the reductions are unnecessary (or if we can't distribute them over the op), we will not be able to fuse them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/14558 Differential Revision: D13321748 Pulled By: zou3519 fbshipit-source-id: c04fc2f70d106d2bfb56206b5aec517a93b79d1f --- aten/src/ATen/core/interned_strings.h | 2 + ...TestScript.test_lstm_fusion_cuda-forward.expect | 113 +++++++------ ...stScript.test_milstm_fusion_cuda-forward.expect | 144 +++++++++------- test/test_jit.py | 10 +- torch/csrc/jit/fuser/executor.cpp | 5 +- torch/csrc/jit/passes/alias_analysis.cpp | 2 + torch/csrc/jit/passes/graph_fuser.cpp | 188 ++++++++++++++++++++- torch/csrc/jit/passes/python_print.cpp | 2 + torch/csrc/jit/register_prim_ops.cpp | 43 +++++ 9 files changed, 382 insertions(+), 127 deletions(-) diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 477d392..ec0e045 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -20,7 +20,9 @@ namespace c10 { _(namespaces, namespaces) \ _(prim, Assign) \ _(prim, BroadcastingChunk) \ + _(prim, BroadcastSizes) \ _(prim, Constant) \ + _(prim, ChunkSizes) \ _(prim, None) \ _(prim, Drop) \ _(prim, Eval) \ diff --git a/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect b/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect index a10d0c3..825e802 100644 --- a/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect +++ b/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect @@ -19,59 +19,68 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *) %8 : Float(*, *) = aten::mm(%5, %7) %9 : Float(*, *) = aten::t(%4) %10 : Float(*, *) = aten::mm(%3, %9) - %11 : int = prim::Constant[value=1]() - %12 : int[] = aten::size(%8) - %13 : int[] = aten::size(%10) - %14 : int[] = aten::size(%2) - %15 : Float(*, *), %16 : Float(*, *) = prim::FusionGroup_0(%2, %8, %10) - %17 : int[] = aten::size(%16) - %18 : int[] = aten::size(%15) - %19 : int[] = aten::size(%1) - %20 : Tensor[] = prim::ListConstruct(%15, %1) - %21 : Tensor[] = aten::broadcast_tensors(%20) - %22 : Tensor, %23 : Tensor = prim::ListUnpack(%21) - %24 : int[] = aten::size(%0) - %hy : Float(*, *), %26 : Float(*, *), %cy : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %23, %22) - %34 : int[] = aten::size(%ingate.1) - %35 : int[] = aten::size(%forgetgate.1) - %36 : int[] = aten::size(%cellgate.1) - %37 : int[] = aten::size(%outgate.1) - %38 : int[] = aten::size(%29) - %39 : int[] = aten::size(%28) - %40 : int[] = aten::size(%26) - return (%hy, %cy, %7, %9, %12, %13, %17, %14, %18, %19, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %35, %24, %34, %36, %38, %39, %26, %37, %40); + %11 : int[] = aten::size(%8) + %12 : int[] = aten::size(%10) + %13 : int[] = aten::size(%2) + %14 : int[] = aten::size(%1) + %15 : Tensor[] = prim::ListConstruct(%1, %2, %8, %10) + %16 : Tensor[] = aten::broadcast_tensors(%15) + %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor = prim::ListUnpack(%16) + %21 : int[] = prim::BroadcastSizes(%11, %12) + %22 : int[] = prim::BroadcastSizes(%21, %13) + %23 : int[] = aten::size(%0) + %hy : Float(*, *), %25 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17) + %31 : int[] = aten::size(%25) + %32 : int[] = aten::size(%outgate.1) + %33 : int[] = aten::size(%cellgate.1) + %34 : int[] = aten::size(%forgetgate.1) + %35 : int[] = aten::size(%ingate.1) + %36 : int[] = prim::BroadcastSizes(%34, %23) + %37 : int[] = prim::BroadcastSizes(%35, %33) + return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %23, %35, %33, %36, %37, %25, %32, %31); } -with prim::FusionGroup_0 = graph(%0 : Float(*) - %1 : Float(*, *) - %2 : Float(*, *)) { - %3 : int = prim::Constant[value=1]() - %4 : Float(*, *) = aten::add(%1, %2, %3) - %5 : int = prim::Constant[value=1]() - %6 : Float(*, *) = aten::add(%4, %0, %5) - return (%6, %4); -} -with prim::FusionGroup_1 = graph(%0 : Float(*, *) +with prim::FusionGroup_0 = graph(%0 : Float(*, *) %1 : Tensor - %2 : Tensor) { - %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2) - %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1) - %11 : int = prim::Constant[value=1]() - %12 : Float(*, *) = aten::add(%3, %7, %11) - %13 : int = prim::Constant[value=1]() - %14 : Float(*, *) = aten::add(%4, %8, %13) - %15 : int = prim::Constant[value=1]() - %16 : Float(*, *) = aten::add(%5, %9, %15) - %17 : int = prim::Constant[value=1]() - %18 : Float(*, *) = aten::add(%6, %10, %17) - %ingate.1 : Float(*, *) = aten::sigmoid(%12) - %forgetgate.1 : Float(*, *) = aten::sigmoid(%14) - %cellgate.1 : Float(*, *) = aten::tanh(%16) - %outgate.1 : Float(*, *) = aten::sigmoid(%18) - %23 : Float(*, *) = aten::mul(%forgetgate.1, %0) - %24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1) + %2 : Tensor + %3 : Tensor + %4 : Tensor) { + %5 : Float(*, *), %6 : Float(*, *), %7 : Float(*, *), %8 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4) + %9 : Float(*, *), %10 : Float(*, *), %11 : Float(*, *), %12 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3) + %13 : Float(*, *), %14 : Float(*, *), %15 : Float(*, *), %16 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2) + %17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1) + %21 : int = prim::Constant[value=1]() + %22 : Float(*, *) = aten::add(%13, %17, %21) + %23 : int = prim::Constant[value=1]() + %24 : Float(*, *) = aten::add(%14, %18, %23) %25 : int = prim::Constant[value=1]() - %cy : Float(*, *) = aten::add(%23, %24, %25) - %27 : Float(*, *) = aten::tanh(%cy) - %hy : Float(*, *) = aten::mul(%outgate.1, %27) - return (%hy, %27, %cy, %24, %23, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1); + %26 : Float(*, *) = aten::add(%15, %19, %25) + %27 : int = prim::Constant[value=1]() + %28 : Float(*, *) = aten::add(%16, %20, %27) + %29 : int = prim::Constant[value=1]() + %30 : Float(*, *) = aten::add(%22, %9, %29) + %31 : int = prim::Constant[value=1]() + %32 : Float(*, *) = aten::add(%24, %10, %31) + %33 : int = prim::Constant[value=1]() + %34 : Float(*, *) = aten::add(%26, %11, %33) + %35 : int = prim::Constant[value=1]() + %36 : Float(*, *) = aten::add(%28, %12, %35) + %37 : int = prim::Constant[value=1]() + %38 : Float(*, *) = aten::add(%30, %5, %37) + %39 : int = prim::Constant[value=1]() + %40 : Float(*, *) = aten::add(%32, %6, %39) + %41 : int = prim::Constant[value=1]() + %42 : Float(*, *) = aten::add(%34, %7, %41) + %43 : int = prim::Constant[value=1]() + %44 : Float(*, *) = aten::add(%36, %8, %43) + %ingate.1 : Float(*, *) = aten::sigmoid(%38) + %forgetgate.1 : Float(*, *) = aten::sigmoid(%40) + %cellgate.1 : Float(*, *) = aten::tanh(%42) + %outgate.1 : Float(*, *) = aten::sigmoid(%44) + %49 : Float(*, *) = aten::mul(%forgetgate.1, %0) + %50 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1) + %51 : int = prim::Constant[value=1]() + %cy : Float(*, *) = aten::add(%49, %50, %51) + %53 : Float(*, *) = aten::tanh(%cy) + %hy : Float(*, *) = aten::mul(%outgate.1, %53) + return (%hy, %53, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1); } diff --git a/test/expect/TestScript.test_milstm_fusion_cuda-forward.expect b/test/expect/TestScript.test_milstm_fusion_cuda-forward.expect index 2e2726c..1c4e483 100644 --- a/test/expect/TestScript.test_milstm_fusion_cuda-forward.expect +++ b/test/expect/TestScript.test_milstm_fusion_cuda-forward.expect @@ -25,69 +25,89 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *) %Uz.1 : Float(*, *) = aten::mm(%5, %11) %13 : int[] = aten::size(%4) %14 : int[] = aten::size(%Wx.1) - %15 : int[] = aten::size(%Uz.1) - %16 : int[] = aten::size(%3) - %17 : int = prim::Constant[value=1]() - %18 : int[] = aten::size(%2) - %19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *), %23 : Float(*, *), %24 : Float(*, *) = prim::FusionGroup_0(%2, %Uz.1, %3, %Wx.1, %4) - %25 : int[] = aten::size(%24) - %26 : int[] = aten::size(%23) - %27 : int[] = aten::size(%22) - %28 : int[] = aten::size(%21) - %29 : int[] = aten::size(%20) - %30 : int[] = aten::size(%19) - %31 : int[] = aten::size(%1) - %32 : Tensor[] = prim::ListConstruct(%19, %1) - %33 : Tensor[] = aten::broadcast_tensors(%32) - %34 : Tensor, %35 : Tensor = prim::ListUnpack(%33) - %36 : int[] = aten::size(%0) - %hy : Float(*, *), %38 : Float(*, *), %cy : Float(*, *), %40 : Float(*, *), %41 : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %35, %34) + %15 : Float(*, *) = aten::mul(%4, %Wx.1) + %16 : int[] = aten::size(%15) + %17 : int[] = aten::size(%Uz.1) + %18 : int[] = aten::size(%3) + %19 : int[] = aten::size(%2) + %20 : int[] = aten::size(%1) + %21 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %15, %3, %Wx.1) + %22 : Tensor[] = aten::broadcast_tensors(%21) + %23 : Tensor, %24 : Tensor, %25 : Tensor, %26 : Tensor, %27 : Tensor, %28 : Tensor = prim::ListUnpack(%22) + %29 : int[] = prim::BroadcastSizes(%18, %14) + %30 : int[] = prim::BroadcastSizes(%16, %17) + %31 : int[] = prim::BroadcastSizes(%19, %17) + %32 : int[] = prim::BroadcastSizes(%30, %29) + %33 : int[] = prim::BroadcastSizes(%32, %31) + %34 : int[] = aten::size(%0) + %hy : Float(*, *), %36 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %28, %27, %26, %25, %24, %23) + %42 : int[] = aten::size(%36) + %43 : int[] = aten::size(%outgate.1) + %44 : int[] = aten::size(%cellgate.1) + %45 : int[] = aten::size(%forgetgate.1) %46 : int[] = aten::size(%ingate.1) - %47 : int[] = aten::size(%forgetgate.1) - %48 : int[] = aten::size(%cellgate.1) - %49 : int[] = aten::size(%outgate.1) - %50 : int[] = aten::size(%41) - %51 : int[] = aten::size(%40) - %52 : int[] = aten::size(%38) - return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %14, %24, %25, %15, %16, %14, %26, %27, %18, %15, %28, %29, %30, %31, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %47, %36, %46, %48, %50, %51, %38, %49, %52); + %47 : int[] = prim::BroadcastSizes(%45, %34) + %48 : int[] = prim::BroadcastSizes(%46, %44) + return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %14, %15, %16, %17, %18, %14, %30, %29, %19, %17, %32, %31, %33, %20, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %45, %34, %46, %44, %47, %48, %36, %43, %42); } -with prim::FusionGroup_0 = graph(%0 : Float(*) - %1 : Float(*, *) - %2 : Float(*) - %3 : Float(*, *) - %4 : Float(*)) { - %5 : Float(*, *) = aten::mul(%4, %3) - %6 : Float(*, *) = aten::mul(%5, %1) - %7 : Float(*, *) = aten::mul(%2, %3) - %8 : int = prim::Constant[value=1]() - %9 : Float(*, *) = aten::add(%6, %7, %8) - %10 : Float(*, *) = aten::mul(%0, %1) - %11 : int = prim::Constant[value=1]() - %12 : Float(*, *) = aten::add(%9, %10, %11) - return (%12, %10, %9, %7, %6, %5); -} -with prim::FusionGroup_1 = graph(%0 : Float(*, *) +with prim::FusionGroup_0 = graph(%0 : Float(*, *) %1 : Tensor - %2 : Tensor) { - %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2) - %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1) - %11 : int = prim::Constant[value=1]() - %12 : Float(*, *) = aten::add(%3, %7, %11) - %13 : int = prim::Constant[value=1]() - %14 : Float(*, *) = aten::add(%4, %8, %13) - %15 : int = prim::Constant[value=1]() - %16 : Float(*, *) = aten::add(%5, %9, %15) - %17 : int = prim::Constant[value=1]() - %18 : Float(*, *) = aten::add(%6, %10, %17) - %ingate.1 : Float(*, *) = aten::sigmoid(%12) - %forgetgate.1 : Float(*, *) = aten::sigmoid(%14) - %cellgate.1 : Float(*, *) = aten::tanh(%16) - %outgate.1 : Float(*, *) = aten::sigmoid(%18) - %23 : Float(*, *) = aten::mul(%forgetgate.1, %0) - %24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1) - %25 : int = prim::Constant[value=1]() - %cy : Float(*, *) = aten::add(%23, %24, %25) - %27 : Float(*, *) = aten::tanh(%cy) - %hy : Float(*, *) = aten::mul(%outgate.1, %27) - return (%hy, %27, %cy, %24, %23, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1); + %2 : Tensor + %3 : Tensor + %4 : Tensor + %5 : Tensor + %6 : Tensor) { + %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%6) + %11 : Float(*, *), %12 : Float(*, *), %13 : Float(*, *), %14 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%5) + %15 : Float(*, *), %16 : Float(*, *), %17 : Float(*, *), %18 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4) + %19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3) + %23 : Float(*, *), %24 : Float(*, *), %25 : Float(*, *), %26 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2) + %27 : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %30 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1) + %31 : Float(*, *) = aten::mul(%23, %27) + %32 : Float(*, *) = aten::mul(%24, %28) + %33 : Float(*, *) = aten::mul(%25, %29) + %34 : Float(*, *) = aten::mul(%26, %30) + %35 : Float(*, *) = aten::mul(%19, %15) + %36 : Float(*, *) = aten::mul(%20, %16) + %37 : Float(*, *) = aten::mul(%21, %17) + %38 : Float(*, *) = aten::mul(%22, %18) + %39 : Float(*, *) = aten::mul(%11, %15) + %40 : Float(*, *) = aten::mul(%12, %16) + %41 : Float(*, *) = aten::mul(%13, %17) + %42 : Float(*, *) = aten::mul(%14, %18) + %43 : int = prim::Constant[value=1]() + %44 : Float(*, *) = aten::add(%35, %31, %43) + %45 : int = prim::Constant[value=1]() + %46 : Float(*, *) = aten::add(%36, %32, %45) + %47 : int = prim::Constant[value=1]() + %48 : Float(*, *) = aten::add(%37, %33, %47) + %49 : int = prim::Constant[value=1]() + %50 : Float(*, *) = aten::add(%38, %34, %49) + %51 : int = prim::Constant[value=1]() + %52 : Float(*, *) = aten::add(%44, %39, %51) + %53 : int = prim::Constant[value=1]() + %54 : Float(*, *) = aten::add(%46, %40, %53) + %55 : int = prim::Constant[value=1]() + %56 : Float(*, *) = aten::add(%48, %41, %55) + %57 : int = prim::Constant[value=1]() + %58 : Float(*, *) = aten::add(%50, %42, %57) + %59 : int = prim::Constant[value=1]() + %60 : Float(*, *) = aten::add(%52, %7, %59) + %61 : int = prim::Constant[value=1]() + %62 : Float(*, *) = aten::add(%54, %8, %61) + %63 : int = prim::Constant[value=1]() + %64 : Float(*, *) = aten::add(%56, %9, %63) + %65 : int = prim::Constant[value=1]() + %66 : Float(*, *) = aten::add(%58, %10, %65) + %ingate.1 : Float(*, *) = aten::sigmoid(%60) + %forgetgate.1 : Float(*, *) = aten::sigmoid(%62) + %cellgate.1 : Float(*, *) = aten::tanh(%64) + %outgate.1 : Float(*, *) = aten::sigmoid(%66) + %71 : Float(*, *) = aten::mul(%forgetgate.1, %0) + %72 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1) + %73 : int = prim::Constant[value=1]() + %cy : Float(*, *) = aten::add(%71, %72, %73) + %75 : Float(*, *) = aten::tanh(%cy) + %hy : Float(*, *) = aten::mul(%outgate.1, %75) + return (%hy, %75, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1); } diff --git a/test/test_jit.py b/test/test_jit.py index 063ac01..a71e31f 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4308,9 +4308,8 @@ a") inputs = get_lstm_inputs('cuda', training=True) module = self.checkScript(LSTMCellS, inputs) forward_graph = module.graph_for(*inputs) - with self.assertRaises(AssertionError): - self.assertGraphContainsExactly( - forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True) + self.assertGraphContainsExactly( + forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True) self.assertExpectedGraph(forward_graph, subname='forward') hy, cy = module(*inputs) @@ -4324,9 +4323,8 @@ a") inputs = get_milstm_inputs('cuda', training=True) module = self.checkScript(MiLSTMCell, inputs) forward_graph = module.graph_for(*inputs) - with self.assertRaises(AssertionError): - self.assertGraphContainsExactly( - forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True) + self.assertGraphContainsExactly( + forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True) self.assertExpectedGraph(forward_graph, subname='forward') hy, cy = module(*inputs) diff --git a/torch/csrc/jit/fuser/executor.cpp b/torch/csrc/jit/fuser/executor.cpp index 2149936..99b01ac 100644 --- a/torch/csrc/jit/fuser/executor.cpp +++ b/torch/csrc/jit/fuser/executor.cpp @@ -28,16 +28,13 @@ static c10::optional> getMapSize( , at::TensorList args , at::IntList arg_subset) { - int64_t dim_after_broadcast = 0; - for (const auto arg_idx : arg_subset) { - dim_after_broadcast = std::max(dim_after_broadcast, args[arg_idx].dim()); - } // TODO: this keeps reallocating map_size at every iteration, but we know // exactly how much storage do we need, so this could be fixed in-place at // every step. We're just missing a few functions for ATen, but the fix // should be straightforward. // Note: left unitialized since empty shape is broadcastable to any shape std::vector map_size; + map_size.reserve(8); for (const auto arg_idx : arg_subset) { auto& arg = args.at(arg_idx); auto& chunk_desc = spec.inputChunks().at(arg_idx); diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp index 5953aac..a47fdcc 100644 --- a/torch/csrc/jit/passes/alias_analysis.cpp +++ b/torch/csrc/jit/passes/alias_analysis.cpp @@ -279,6 +279,8 @@ void AliasDb::analyze(Node* node) { case prim::MMTreeReduce: case prim::MMBatchSide: case prim::None: + case prim::BroadcastSizes: + case prim::ChunkSizes: return analyzeCreator(node); case prim::TupleUnpack: case prim::TupleIndex: diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index 85717a5..ab440f2 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -2,6 +2,7 @@ #include "torch/csrc/jit/passes/alias_analysis.h" #include "torch/csrc/jit/passes/common_subexpression_elimination.h" +#include "torch/csrc/jit/passes/dead_code_elimination.h" #include "torch/csrc/jit/symbolic_variable.h" #include "torch/csrc/jit/fuser/interface.h" #include "torch/csrc/jit/operator.h" @@ -112,6 +113,14 @@ bool isSimpleMap(Node *node) { return true; } +Value * broadcastSizes(at::ArrayRef sizes) { + JIT_ASSERT(!sizes.empty()); + Graph * graph = sizes[0]->owningGraph(); + Node * broadcast_n = graph->insertNode(graph->create(prim::BroadcastSizes, sizes)); + broadcast_n->output()->setType(ListType::ofInts()); + return broadcast_n->output(); +} + struct GraphFuser { Block * block_; std::shared_ptr graph_; @@ -159,11 +168,15 @@ struct GraphFuser { return isFusableCatNode(node) || node->kind() == prim::FusedConcat; } - bool allUsersAreThisConsumer(Node * consumer, Value * producer) { + bool calculatesSize(Node * node) { + return node->matches("aten::size(Tensor self) -> int[]"); + } + + bool allUsersAreThisConsumerOrCalcSizes(Node * consumer, Value * producer) { auto defining_node = producer->node(); for(auto o : defining_node->outputs()) { for(auto u : o->uses()) { - if(u.user != consumer) + if(u.user != consumer && !calculatesSize(u.user)) return false; } } @@ -610,7 +623,7 @@ struct GraphFuser { chunk->inputs().end(), [&](Value * producer_for_chunk) { return isFusable(producer_for_chunk->node()) && - allUsersAreThisConsumer(chunk, producer_for_chunk); + allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk); }); if (it == chunk->inputs().end()) { return false; @@ -707,6 +720,22 @@ struct GraphFuser { for (size_t i = 0; i < nchunks; i++) { bchunk->eraseOutput(nchunks * producer_index); } + + // The output of producer_for_chunk_node could have been used in some aten::size + // operators, so we need to clean those up as well (we simply broadcast all its tensor inputs). + auto size_calc_uses = producer_for_chunk_node->output()->uses(); + if (!size_calc_uses.empty()) { + auto tensor_inputs = filter(producer_for_chunk_node->inputs(), + [](Value * v) { return v->type()->isSubtypeOf(DynamicType::get()); }); + auto tensor_sizes = fmap(tensor_inputs, + [](Value * v) { return v->owningGraph()->insert(aten::size, {v}); }); + JIT_ASSERT(!tensor_sizes.empty()); + Value * output_size = tensor_sizes.size() == 1 ? tensor_sizes[0] : broadcastSizes(tensor_sizes); + for (Use u : size_calc_uses) { + u.user->output()->replaceAllUsesWith(output_size); + u.user->destroy(); + } + } producer_for_chunk_node->destroy(); return true; } @@ -773,6 +802,98 @@ struct GraphFuser { } } + bool usedOnlyInSize(Value * v) { + const auto & uses = v->uses(); + return std::all_of(uses.begin(), uses.end(), + [](const Use& u) { return u.user->matches("aten::size(Tensor self) -> int[]"); }); + } + + std::unordered_map buildShapeExpressions(Node * fusion_group) { + WithInsertPoint insert_guard { fusion_group->next() }; + std::unordered_map shape_of; + + Graph * graph = fusion_group->owningGraph(); + auto subgraph = fusion_group->g(attr::Subgraph); + + auto inputs = fusion_group->inputs(); + auto sinputs = subgraph->inputs(); + JIT_ASSERT(inputs.size() == sinputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + shape_of[sinputs[i]] = graph->insert(aten::size, {inputs[i]}); + } + + // When we have a guarantee that an output won't be removed, because it's + // used in expressions that don't involve size checks, we can use its size + // instead of computing a long chain of broadcasts, starting from the beginning + // of the kernel. + auto outputs = fusion_group->outputs(); + auto soutputs = subgraph->outputs(); + JIT_ASSERT(outputs.size() == soutputs.size()); + for (size_t i = 0; i < outputs.size(); ++i) { + if (usedOnlyInSize(outputs[i])) continue; + shape_of[soutputs[i]] = graph->insert(aten::size, {outputs[i]}); + } + + for (Node * n : subgraph->nodes()) { + // XXX: Use of shape_of.emplace is crucial to the output shape optimization! + if (n->kind() == prim::FusedConcat) { + // This is a bit more involved, because we have to account for the case + // when inputs have different shapes, but fortunately those tensors are + // always outputs, and so we can simply avoid replacing their queries, + // because it won't help us. + continue; + } + if (n->kind() == prim::Constant) { + continue; + } + if (n->kind() == prim::ConstantChunk) { + Node * sizes_node = graph->insertNode(graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2)); + sizes_node->i_(attr::dim, n->i(attr::dim)); + sizes_node->i_(attr::chunks, n->i(attr::chunks)); + Value * regular_size = sizes_node->outputs().at(0); + Value * last_size = sizes_node->outputs().at(1); + auto outputs = n->outputs(); + for (Value * o : outputs.slice(0, outputs.size() - 1)) { + shape_of.emplace(o, regular_size); + } + shape_of.emplace(outputs.at(outputs.size() - 1), last_size); + continue; + } + auto tensor_inputs = filter(n->inputs(), + [](Value * v) { return v->type()->isSubtypeOf(DynamicType::get()); }); + auto shapes = fmap(tensor_inputs, [&](Value * v) { return shape_of.at(v); }); + JIT_ASSERT(!shapes.empty()); + shape_of.emplace(n->output(), shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes)); + } + return shape_of; + } + + void removeOutputsUsedOnlyInSize(Node * fusion_group) { + if (fusion_group->kind() != prim::FusionGroup) return; + auto subgraph = fusion_group->g(attr::Subgraph); + + auto shape_of = buildShapeExpressions(fusion_group); + auto outputs = fusion_group->outputs().vec(); + auto soutputs = subgraph->outputs().vec(); + // XXX: Iterating in this order is not only good for performance reasons! + // It is also crucial for correctness (i has to reflect the current true + // index of outputs[i])! + for (int64_t i = static_cast(outputs.size()) - 1; i >= 0; --i) { + auto output = outputs[i]; + auto soutput = soutputs[i]; + if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) { + auto uses = output->uses(); + for (Use u : uses) { + JIT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]")); + u.user->output()->replaceAllUsesWith(shape_of.at(soutput)); + u.user->destroy(); + } + fusion_group->eraseOutput(i); + subgraph->eraseOutput(i); + } + } + } + void run() { // Run the pass until no changes are made. // This is neccessary, because the algorithm can miss out on certain fusion @@ -809,6 +930,12 @@ struct GraphFuser { for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) { it = scanNodeForChunks(*it); } + + // Remove outputs that have been added only because we need their size + for (Node * n : block_->nodes()) { + removeOutputsUsedOnlyInSize(n); + } + for (Node * node : block_->nodes()) { for (Block * sub_block : node->blocks()) { GraphFuser(sub_block, graph_).run(); @@ -817,6 +944,56 @@ struct GraphFuser { } }; +void PeepholeOptimizeShapeExpressions(Block * block) { + auto nodes = block->nodes(); + for (auto it = nodes.begin(); it != nodes.end(); ++it) { + Node * node = *it; + for (Block * subblock : node->blocks()) { + PeepholeOptimizeShapeExpressions(subblock); + } + if (node->kind() == prim::BroadcastSizes) { + // Remove no-op broadcasts. + if (node->inputs().size() == 1) { + node->output()->replaceAllUsesWith(node->input()); + it.destroyCurrent(); + continue; + } + // Deduplicate inputs, but use their unique() values to ensure + // this process only depends on the graph. + std::map unique_to_value; + for (Value * input : node->inputs()) { + unique_to_value.emplace(input->unique(), input); + } + if (unique_to_value.size() != node->inputs().size()) { + std::vector inputs; + for (auto & entry : unique_to_value) { + inputs.push_back(entry.second); + } + if (inputs.size() == 1) { + node->output()->replaceAllUsesWith(inputs[0]); + } else { + WithInsertPoint insert_guard { node }; + node->output()->replaceAllUsesWith(broadcastSizes(inputs)); + } + it.destroyCurrent(); + --it; // Revisit the node with deduplicated inputs + continue; + } + // Remove compose simple chains of broadcasts into a single node. + const auto & uses = node->output()->uses(); + if (uses.size() == 1 && uses[0].user->kind() == prim::BroadcastSizes) { + Node * user = uses[0].user; + user->removeInput(uses[0].offset); + // NB: we don't care about deduplication in here, as we will visit user later. + for (Value * i : node->inputs()) { + user->addInput(i); + } + it.destroyCurrent(); + } + } + } +} + } // anonymous namespace void FuseGraph(std::shared_ptr& graph) { @@ -826,6 +1003,11 @@ void FuseGraph(std::shared_ptr& graph) { GraphFuser(graph->block(), graph).run(); // After FuseGraph some common subexpressions may come back EliminateCommonSubexpression(graph); + // We might have emitted a fair amount of useless shape propagating code, so + // remove it + EliminateDeadCode(graph); + // Improve the quality of shape propagation code that was left + PeepholeOptimizeShapeExpressions(graph->block()); #endif } diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index 311b360..7f7e761 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -1012,6 +1012,8 @@ TORCH_API bool printerHasSpecialCaseFor(Symbol sym) { prim::AutogradAdd, // temporarily inserted by autograd prim::ConstantChunk, // optimization pass adds it prim::DifferentiableGraph, // optimization pass adds it + prim::BroadcastSizes, // optimization pass (fuser) adds it + prim::ChunkSizes, // optimization pass (fuser) adds it prim::Drop, // used in interpreter only prim::FusedConcat, // optimization pass adds it prim::FusionGroup, // optimization pass adds it diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 72ba0cc..7ed54c0 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -12,6 +12,7 @@ #include "torch/csrc/variable_tensor_functions.h" #include +#include #include #include @@ -287,6 +288,48 @@ RegisterOperators reg({ }; }), Operator( + prim::BroadcastSizes, + [](const Node* node) -> Operation { + size_t num_inputs = node->inputs().size(); + return [num_inputs](Stack& stack) { + std::vector size; + size.reserve(8); + for (size_t i = 0; i < num_inputs; ++i) { + size = at::infer_size(size, peek(stack, i, num_inputs).toIntList()->elements()); + } + drop(stack, num_inputs); + push(stack, std::move(size)); + return 0; + }; + }), + Operator( + prim::ChunkSizes, + [](const Node* node) -> Operation { + int64_t raw_dim = node->i(attr::dim); + int64_t chunks = node->i(attr::chunks); + return [raw_dim, chunks](Stack& stack) { + Shared sizes_l; + pop(stack, sizes_l); + const auto & shape = sizes_l->elements(); + std::vector regular_shape = shape; + std::vector last_shape = shape; + int64_t dim = at::maybe_wrap_dim(raw_dim, shape.size()); + AT_CHECK(dim < regular_shape.size(), "Dimension out of range for chunk"); + int64_t split_size = (regular_shape[dim] + chunks - 1) / chunks; + regular_shape[dim] = split_size; + if (shape[dim] % chunks == 0) { + last_shape[dim] = split_size; + } else { + int64_t num_splits = std::max((shape[dim] + split_size - 1) / split_size, 1); + last_shape[dim] = split_size - (split_size * num_splits - shape[dim]); + JIT_ASSERT(last_shape[dim] >= 0); + } + push(stack, std::move(regular_shape)); + push(stack, std::move(last_shape)); + return 0; + }; + }), + Operator( "prim::RaiseException(str msg) -> ()", [](const Node* node) -> Operation { return [](Stack& stack) { -- 2.7.4