From 86c96542914bf9b3dfda0c7f6373fd13b48c6b97 Mon Sep 17 00:00:00 2001
From: Salil Desai
Date: Wed, 1 Sep 2021 14:08:02 -0700
Subject: [PATCH] Update optimize_for_mobile to preserve node's debug
 information (#63106)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63106

Propagate debug info (source ranges) from the original nodes to the
re-written nodes in the graph, so that nodes produced by the
optimize_for_mobile passes still point back to the source they came from.
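
As an illustration of the invariant this enforces, the following standalone
sketch (not part of the patch; it assumes a PyTorch build with XNNPACK
enabled) mirrors what the new test checks:

```
import torch

# Trace a linear layer; its graph contains an aten::linear node.
model = torch.jit.trace(torch.nn.Linear(5, 4), torch.rand(3, 2, 5))
linear_range = next(
    n.sourceRange() for n in model.graph.nodes() if n.kind() == "aten::linear"
)

# Rewrite aten::linear into prepacked::linear_clamp_prepack/_run.
torch._C._jit_pass_insert_prepacked_ops(model._c)

# With this change, the replacement nodes report the source range of
# the aten::linear node they replaced.
for node in model.graph.nodes():
    if node.kind().startswith("prepacked::linear_clamp"):
        assert node.sourceRange() == linear_range
```
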
Test Plan:
- Clone open source repo and build
- Run:
```
python3 test/test_jit.py TestOptimizeForMobilePreserveDebugInfo
```
- Tests pass

Reviewed By: kimishpatel

Differential Revision: D28654659

fbshipit-source-id: 2d7c87f2fb95a3be53246375f35639bbd97c237e
---
 ...test_optimize_for_mobile_preserve_debug_info.py | 261 +++++++++++++++++++++
 test/test_jit.py                                   |   1 +
 torch/csrc/jit/passes/xnnpack_rewrite.cpp          | 165 ++++++++++---
 3 files changed, 388 insertions(+), 39 deletions(-)
 create mode 100644 test/jit/test_optimize_for_mobile_preserve_debug_info.py

diff --git a/test/jit/test_optimize_for_mobile_preserve_debug_info.py b/test/jit/test_optimize_for_mobile_preserve_debug_info.py
new file mode 100644
index 0000000..c08f3b5
--- /dev/null
+++ b/test/jit/test_optimize_for_mobile_preserve_debug_info.py
@@ -0,0 +1,261 @@
+import torch
+import torch._C
+import torch.backends.xnnpack
+import torch.nn.functional as F
+from torch.testing._internal.jit_utils import JitTestCase
+
+class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
+    def check_replacement(
+        self,
+        model,
+        replacements,
+        jit_pass,
+    ):
+        """
+        model: Model which optimization is performed on
+        replacements: Dict mapping from nodes' kinds in the optimized model
+            to the kinds of nodes they replaced in the original model
+        jit_pass: Function to perform optimization
+        """
+
+        original_kinds = set(replacements.values())
+        original_source_ranges = {
+            node.kind(): node.sourceRange()
+            for node in model.graph.nodes()
+            if node.kind() in original_kinds
+        }
+
+        jit_pass(model._c)
+
+        for node in model.graph.nodes():
+            if node.kind() in replacements:
+                self.assertEqual(
+                    node.sourceRange(),
+                    original_source_ranges[replacements[node.kind()]],
+                )
+
+    def test_replace_conv1d_with_conv2d(self):
+        class TestConv1d(torch.nn.Module):
+            def __init__(self, weight, bias):
+                super(TestConv1d, self).__init__()
+                self.weight = weight
+                self.bias = bias
+
+            def forward(self, x):
+                return F.conv1d(x, self.weight, self.bias)
+
+        self.check_replacement(
+            model=torch.jit.script(
+                TestConv1d(
+                    weight=torch.rand(3, 3, 3),
+                    bias=torch.rand(3),
+                ),
+            ),
+            replacements={
+                "prim::ListUnpack": "aten::conv1d",
+                "prim::ListConstruct": "aten::conv1d",
+                "aten::unsqueeze": "aten::conv1d",
+                "aten::conv2d": "aten::conv1d",
+                "aten::squeeze": "aten::conv1d",
+            },
+            jit_pass=torch._C._jit_pass_transform_conv1d_to_conv2d,
+        )
+
+    def test_insert_pre_packed_linear_before_inline_and_conv_2d_op(self):
+        class TestPrepackedLinearBeforeInlineAndConv2dOp(torch.nn.Module):
+            def __init__(
+                self,
+                linear_weight,
+                linear_bias,
+                conv2d_weight,
+                conv2d_bias,
+                conv_transpose2d_weight,
+                conv_transpose2d_bias,
+            ):
+                super(
+                    TestPrepackedLinearBeforeInlineAndConv2dOp,
+                    self,
+                ).__init__()
+                self.linear_weight = linear_weight.float()
+                self.linear_bias = linear_bias.float()
+                self.conv2d_weight = conv2d_weight.float()
+                self.conv2d_bias = conv2d_bias.float()
+                self.conv_transpose2d_weight = conv_transpose2d_weight.float()
+                self.conv_transpose2d_bias = conv_transpose2d_bias.float()
+
+            def forward(self, x):
+                linear_res = F.linear(
+                    x.float(),
+                    self.linear_weight,
+                    self.linear_bias,
+                )
+                conv2d_res = F.conv2d(
+                    input=linear_res.unsqueeze(dim=0).float(),
+                    weight=self.conv2d_weight,
+                    bias=self.conv2d_bias,
+                )
+                return F.conv_transpose2d(
+                    input=conv2d_res,
+                    weight=self.conv_transpose2d_weight,
+                    bias=self.conv_transpose2d_bias,
+                )
+
+        minibatch = 1
+        in_channels = 6
+        iH = 4
+        iW = 5
+        out_channels = 6
+        kH = 2
+        kW = 3
+
+        self.check_replacement(
+            model=torch.jit.script(
+                TestPrepackedLinearBeforeInlineAndConv2dOp(
+                    linear_weight=torch.rand(iW, 3),
+                    linear_bias=torch.rand(iW),
+                    conv2d_weight=torch.rand(out_channels, in_channels, kH, kW),
+                    conv2d_bias=torch.rand(out_channels),
+                    conv_transpose2d_weight=torch.rand(
+                        out_channels,
+                        in_channels,
+                        kH,
+                        kW,
+                    ),
+                    conv_transpose2d_bias=torch.rand(out_channels),
+                ),
+            ),
+            replacements={
+                "prepacked::linear_clamp_prepack": "prim::CallFunction",
+                "prepacked::linear_clamp_run": "prim::CallFunction",
+                "prepacked::conv2d_clamp_prepack": "aten::conv2d",
+                "prepacked::conv2d_clamp_run": "aten::conv2d",
+                "prepacked::conv2d_transpose_clamp_prepack":
+                    "aten::conv_transpose2d",
+                "prepacked::conv2d_transpose_clamp_run":
+                    "aten::conv_transpose2d",
+            },
+            jit_pass=torch._C._jit_pass_insert_prepacked_ops,
+        )
+
+    def test_insert_pre_packed_linear_op(self):
+        self.check_replacement(
+            model=torch.jit.trace(torch.nn.Linear(5, 4), torch.rand(3, 2, 5)),
+            replacements={
+                "prepacked::linear_clamp_prepack": "aten::linear",
+                "prepacked::linear_clamp_run": "aten::linear"
+            },
+            jit_pass=torch._C._jit_pass_insert_prepacked_ops,
+        )
+
+    def run_test_fuse_activation_with_pack_ops_linear_conv2d(
+        self,
+        linear_activation,
+        linear_activation_kind,
+        conv2d_activation,
+        conv2d_activation_kind,
+    ):
+        class TestFuseActivationLinearConv2d(torch.nn.Module):
+            def __init__(
+                self,
+                linear_weight,
+                linear_bias,
+                conv2d_weight,
+                conv2d_bias,
+            ):
+                super(TestFuseActivationLinearConv2d, self).__init__()
+                self.linear_weight = linear_weight
+                self.linear_bias = linear_bias
+                self.conv2d_weight = conv2d_weight
+                self.conv2d_bias = conv2d_bias
+
+            def forward(self, x):
+                x = F.linear(
+                    input=x,
+                    weight=self.linear_weight,
+                    bias=self.linear_bias,
+                )
+                x = linear_activation(x)
+                x = F.conv2d(
+                    input=x.unsqueeze(dim=0),
+                    weight=self.conv2d_weight,
+                    bias=self.conv2d_bias,
+                )
+                return conv2d_activation(x)
+
+        linear_in_features = 5
+        linear_out_features = 4
+        conv2d_in_channels = 3
+        conv2d_out_channels = 4
+        conv2d_kernel = 2
+        x_shape = (3, 2, 5)
+
+        model = torch.jit.trace(
+            TestFuseActivationLinearConv2d(
+                linear_weight=torch.nn.Parameter(
+                    data=torch.rand(
+                        linear_out_features,
+                        linear_in_features,
+                    ),
+                    requires_grad=False,
+                ),
+                linear_bias=torch.nn.Parameter(
+                    data=torch.rand(linear_out_features),
+                    requires_grad=False,
+                ),
+                conv2d_weight=torch.rand(
+                    conv2d_out_channels,
+                    conv2d_in_channels,
+                    conv2d_kernel,
+                    conv2d_kernel,
+                ),
+                conv2d_bias=torch.rand(conv2d_out_channels),
+            ),
+            torch.rand(x_shape),
+        )
+
+        torch._C._jit_pass_insert_prepacked_ops(model._c)
+
+        self.check_replacement(
+            model=model,
+            replacements={
+                "prepacked::linear_clamp_prepack":
+                    "prepacked::linear_clamp_prepack",
+                "prepacked::linear_clamp_run": linear_activation_kind,
+                "prepacked::conv2d_clamp_prepack":
+                    "prepacked::conv2d_clamp_prepack",
+                "prepacked::conv2d_clamp_run": conv2d_activation_kind,
+            },
+            jit_pass=torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv,
+        )
+
+    def test_fuse_activation_with_pack_ops_linear_conv2d_1(self):
+        self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
+            linear_activation=F.hardtanh,
+            linear_activation_kind="aten::hardtanh",
+            conv2d_activation=F.hardtanh_,
+            conv2d_activation_kind="aten::hardtanh_"
+        )
+
+    def test_fuse_activation_with_pack_ops_linear_conv2d_2(self):
+        self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
+            linear_activation=F.hardtanh_,
+            linear_activation_kind="aten::hardtanh_",
+            conv2d_activation=F.hardtanh,
+            conv2d_activation_kind="aten::hardtanh"
+        )
+
+    def test_fuse_activation_with_pack_ops_linear_conv2d_3(self):
+        self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
+            linear_activation=F.relu,
+            linear_activation_kind="aten::relu",
+            conv2d_activation=F.relu_,
+            conv2d_activation_kind="aten::relu_"
+        )
+
+    def test_fuse_activation_with_pack_ops_linear_conv2d_4(self):
+        self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
+            linear_activation=F.relu_,
+            linear_activation_kind="aten::relu_",
+            conv2d_activation=F.relu,
+            conv2d_activation_kind="aten::relu"
+        )
diff --git a/test/test_jit.py b/test/test_jit.py
index e94ed8d..8d1981d 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -61,6 +61,7 @@ from jit.test_convert_activation import TestFunctionalToInplaceActivation, TestI
 from jit.test_parametrization import TestParametrization  # noqa: F401
 from jit.test_attr import TestGetDefaultAttr  # noqa: F401
 from jit.test_aten_pow import TestAtenPow  # noqa: F401
+from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo  # noqa: F401
 
 # Torch
 from torch import Tensor
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
index 11210a4..9b2cac6 100644
--- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
@@ -26,8 +26,8 @@ namespace {
 void replaceConv1dWithConv2d(std::shared_ptr<Graph>& graph) {
   std::string conv_1d_pattern = R"(
     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
-        %r = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
-        return (%r) )";
+        %res = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
+        return (%res) )";
 
   std::string conv_2d_pattern = R"(
     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
@@ -47,8 +47,24 @@ void replaceConv1dWithConv2d(std::shared_ptr<Graph>& graph) {
        %output : Tensor = aten::squeeze(%output_2d, %two)
        return (%output) )";
 
+  std::vector<std::pair<std::string, std::string>> value_mappings(
+      {{"zero", "res"},
+       {"one", "res"},
+       {"stride_w", "res"},
+       {"stride_2d", "res"},
+       {"padding_w", "res"},
+       {"padding_2d", "res"},
+       {"dilation_w", "res"},
+       {"dilation_2d", "res"},
+       {"two", "res"},
+       {"input_2d", "res"},
+       {"weight_2d", "res"},
+       {"output_2d", "res"},
+       {"output", "res"}});
+
   SubgraphRewriter rewriter;
-  rewriter.RegisterRewritePattern(conv_1d_pattern, conv_2d_pattern);
+  rewriter.RegisterRewritePattern(
+      conv_1d_pattern, conv_2d_pattern, value_mappings);
   rewriter.runOnGraph(graph);
 }
@@ -80,8 +96,8 @@ void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
   std::string linear_before_inline = R"(
     graph(%linear, %input, %weight, %bias):
-        %r = prim::CallFunction(%linear, %input, %weight, %bias)
-        return (%r))";
+        %res = prim::CallFunction(%linear, %input, %weight, %bias)
+        return (%res))";
   std::string prepacked_ops_pattern_before_inline = R"(
     graph(%linear, %input, %weight, %bias):
         %output_min_max : None = prim::Constant()
         %packed_weight_bias = prepacked::linear_clamp_prepack(
             %weight, %bias, %output_min_max, %output_min_max)
         %res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
         return (%res))";
   std::string linear_pattern = R"(
     graph(%input, %weight, %bias):
-        %r = aten::linear(%input, %weight, %bias)
-        return (%r))";
+        %res = aten::linear(%input, %weight, %bias)
+        return (%res))";
   std::string prepacked_ops_pattern = R"(
     graph(%input, %weight, %bias):
         %output_min_max : None = prim::Constant()
@@ -112,13 +128,24 @@ void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
     return false;
   };
 
+  std::vector<std::pair<std::string, std::string>> value_mappings(
+      {{"output_min_max", "res"},
+       {"packed_weight_bias", "res"},
+       {"res", "res"}});
+
   SubgraphRewriter linear_call_fn_rewriter;
   linear_call_fn_rewriter.RegisterRewritePattern(
-      linear_before_inline, prepacked_ops_pattern_before_inline);
+      linear_before_inline,
+      prepacked_ops_pattern_before_inline,
+      value_mappings);
   linear_call_fn_rewriter.runOnGraph(graph, filter);
 
+  value_mappings = {
+      {"output_min_max", "res"}, {"packed_weight_bias", "res"}, {"res", "res"}};
+
   SubgraphRewriter linear_rewriter;
-  linear_rewriter.RegisterRewritePattern(linear_pattern, prepacked_ops_pattern);
+  linear_rewriter.RegisterRewritePattern(
+      linear_pattern, prepacked_ops_pattern, value_mappings);
   linear_rewriter.runOnGraph(graph);
 }
@@ -128,8 +155,8 @@ void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
   std::string conv_2d_pattern = R"(
     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
-        %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
-        return (%r) )";
+        %res = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
+        return (%res) )";
 
   std::string prepacked_ops_conv2d_pattern = R"(
     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
@@ -137,19 +164,24 @@ void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
         %packed_weight_bias = prepacked::conv2d_clamp_prepack(
             %weight, %bias, %stride, %padding, %dilation, %groups,
             %output_min_max, %output_min_max)
-        %r = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
-        return (%r) )";
+        %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+        return (%res) )";
+
+  std::vector<std::pair<std::string, std::string>> value_mappings(
+      {{"output_min_max", "res"},
+       {"packed_weight_bias", "res"},
+       {"res", "res"}});
 
   SubgraphRewriter rewriter;
   rewriter.RegisterRewritePattern(
-      conv_2d_pattern, prepacked_ops_conv2d_pattern);
+      conv_2d_pattern, prepacked_ops_conv2d_pattern, value_mappings);
   rewriter.runOnGraph(graph);
 
   std::string conv_2d_transpose_pattern = R"(
     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int):
-      %r = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation)
-      return (%r) )";
+      %res = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation)
+      return (%res) )";
 
   std::string prepacked_ops_conv2d_transpose_pattern = R"(
     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int):
         %output_min_max : None = prim::Constant()
@@ -157,12 +189,17 @@ void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
         %packed_weight_bias = prepacked::conv2d_transpose_clamp_prepack(
             %weight, %bias, %stride, %padding, %output_padding, %dilation, %groups,
             %output_min_max, %output_min_max)
-        %r = prepacked::conv2d_transpose_clamp_run(%input, %packed_weight_bias)
-        return (%r) )";
+        %res = prepacked::conv2d_transpose_clamp_run(%input, %packed_weight_bias)
+        return (%res) )";
+
+  value_mappings = {
+      {"output_min_max", "res"}, {"packed_weight_bias", "res"}, {"res", "res"}};
 
   SubgraphRewriter transpose_rewriter;
   transpose_rewriter.RegisterRewritePattern(
-      conv_2d_transpose_pattern, prepacked_ops_conv2d_transpose_pattern);
+      conv_2d_transpose_pattern,
+      prepacked_ops_conv2d_transpose_pattern,
+      value_mappings);
   transpose_rewriter.runOnGraph(graph);
 }
@@ -182,8 +219,8 @@ void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
         %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack(
             %weight, %bias, %stride, %padding, %dilation, %groups,
             %output_min, %output_max)
-        %r = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
-        return (%r) )";
+        %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+        return (%res) )";
 
   std::string linear_prepack_run_hardtanh = R"(
     graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
         %packed_weight_bias = prepacked::linear_clamp_prepack(
             %weight, %bias, %output_min, %output_max)
         %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
         %res = aten::hardtanh(%linear_res, %output_min, %output_max)
         return (%res))";
 
+  std::vector<std::pair<std::string, std::string>> value_mappings(
+      {{"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}});
+
   rewriter.RegisterRewritePattern(
-      linear_prepack_run_hardtanh, linear_prepack_run_hardtanh_fused);
+      linear_prepack_run_hardtanh,
+      linear_prepack_run_hardtanh_fused,
+      value_mappings);
 
   std::string conv2d_prepack_run_hardtanh = R"(
     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
           %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
         %packed_weight_bias = prepacked::conv2d_clamp_prepack(
             %weight, %bias, %stride, %padding, %dilation, %groups,
             %dummy_min_max, %dummy_min_max)
         %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
-        %r = aten::hardtanh(%conv2d_res, %output_min, %output_max)
-        return (%r) )";
+        %res = aten::hardtanh(%conv2d_res, %output_min, %output_max)
+        return (%res) )";
+
+  value_mappings = {
+      {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}};
 
   rewriter.RegisterRewritePattern(
-      conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);
+      conv2d_prepack_run_hardtanh,
+      conv2d_prepack_run_hardtanh_fused,
+      value_mappings);
 
   std::string linear_prepack_run_hardtanh_inplace = R"(
     graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
@@ -224,13 +271,24 @@ void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
             %weight, %bias, %stride, %padding, %dilation, %groups,
             %dummy_min_max, %dummy_min_max)
         %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
-        %r = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
-        return (%r) )";
+        %res = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
+        return (%res) )";
+
+  value_mappings = {
+      {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}};
 
   rewriter.RegisterRewritePattern(
-      linear_prepack_run_hardtanh_inplace, linear_prepack_run_hardtanh_fused);
+      linear_prepack_run_hardtanh_inplace,
+      linear_prepack_run_hardtanh_fused,
+      value_mappings);
+
+  value_mappings = {
+      {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}};
+
   rewriter.RegisterRewritePattern(
-      conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);
+      conv2d_prepack_run_hardtanh_inplace,
+      conv2d_prepack_run_hardtanh_fused,
+      value_mappings);
 
   rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
 }
@@ -255,8 +313,8 @@ void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
         %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack(
             %weight, %bias, %stride, %padding, %dilation, %groups,
             %output_min, %output_max)
-        %r = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
-        return (%r) )";
+        %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+        return (%res) )";
 
   std::string linear_prepack_run_relu = R"(
     graph(%input, %weight, %bias, %dummy_min_max):
         %packed_weight_bias = prepacked::linear_clamp_prepack(
             %weight, %bias, %dummy_min_max, %dummy_min_max)
         %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
         %res = aten::relu(%linear_res)
         return (%res))";
 
+  std::vector<std::pair<std::string, std::string>> value_mappings(
+      {{"output_min", "packed_weight_bias"},
+       {"output_max", "packed_weight_bias"},
+       {"packed_weight_bias", "packed_weight_bias"},
+       {"res", "res"}});
+
   rewriter.RegisterRewritePattern(
-      linear_prepack_run_relu, linear_prepack_run_relu_fused);
+      linear_prepack_run_relu, linear_prepack_run_relu_fused, value_mappings);
 
   std::string conv2d_prepack_run_relu = R"(
     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
           %dilation:int[], %groups:int, %dummy_min_max):
         %packed_weight_bias = prepacked::conv2d_clamp_prepack(
             %weight, %bias, %stride, %padding, %dilation, %groups,
             %dummy_min_max, %dummy_min_max)
         %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
-        %r = aten::relu(%conv2d_res)
-        return (%r) )";
+        %res = aten::relu(%conv2d_res)
+        return (%res) )";
+
+  value_mappings = {
+      {"output_min", "packed_weight_bias"},
+      {"output_max", "packed_weight_bias"},
+      {"packed_weight_bias", "packed_weight_bias"},
+      {"res", "res"}};
 
   rewriter.RegisterRewritePattern(
-      conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);
+      conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused, value_mappings);
 
   std::string linear_prepack_run_relu_inplace = R"(
     graph(%input, %weight, %bias, %dummy_min_max):
@@ -297,13 +367,30 @@ void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
             %weight, %bias, %stride, %padding, %dilation, %groups,
             %dummy_min_max, %dummy_min_max)
         %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
-        %r = aten::relu_(%conv2d_res)
-        return (%r) )";
+        %res = aten::relu_(%conv2d_res)
+        return (%res) )";
+
+  value_mappings = {
+      {"output_min", "packed_weight_bias"},
+      {"output_max", "packed_weight_bias"},
+      {"packed_weight_bias", "packed_weight_bias"},
+      {"res", "res"}};
 
   rewriter.RegisterRewritePattern(
-      linear_prepack_run_relu_inplace, linear_prepack_run_relu_fused);
+      linear_prepack_run_relu_inplace,
+      linear_prepack_run_relu_fused,
+      value_mappings);
+
+  value_mappings = {
+      {"output_min", "packed_weight_bias"},
+      {"output_max", "packed_weight_bias"},
+      {"packed_weight_bias", "packed_weight_bias"},
+      {"res", "res"}};
+
   rewriter.RegisterRewritePattern(
-      conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);
+      conv2d_prepack_run_relu_inplace,
+      conv2d_prepack_run_relu_fused,
+      value_mappings);
   rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
 }
-- 
2.7.4
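
The patched passes above are the ones torch.utils.mobile_optimizer.optimize_for_mobile
runs internally, so the preserved debug info should also be observable through
the public entry point. A minimal sketch, assuming an XNNPACK-enabled build
(optimize_for_mobile additionally freezes and fuses the module, so node kinds
other than the one checked here will differ from the traced graph):

```
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torch.jit.trace(torch.nn.Linear(5, 4), torch.rand(3, 2, 5))

# Record source ranges by node kind before optimizing.
ranges_before = {n.kind(): n.sourceRange() for n in model.graph.nodes()}

optimized = optimize_for_mobile(model)

# The prepacked linear op reports the original aten::linear call site.
for node in optimized.graph.nodes():
    if node.kind() == "prepacked::linear_clamp_run":
        assert node.sourceRange() == ranges_before["aten::linear"]
```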