Update optimize_for_mobile to preserve node's debug information (#63106)
authorSalil Desai <salilsdesai@fb.com>
Wed, 1 Sep 2021 21:08:02 +0000 (14:08 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 21:34:20 +0000 (14:34 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63106

Propagate debug info to the re-written nodes in the graph.

Test Plan:
- Clone open source repo and build
- ``` python3 test/test_jit.py TestOptimizeForMobilePreserveDebugInfo ```
- Tests pass

Reviewed By: kimishpatel

Differential Revision: D28654659

fbshipit-source-id: 2d7c87f2fb95a3be53246375f35639bbd97c237e

test/jit/test_optimize_for_mobile_preserve_debug_info.py [new file with mode: 0644]
test/test_jit.py
torch/csrc/jit/passes/xnnpack_rewrite.cpp

diff --git a/test/jit/test_optimize_for_mobile_preserve_debug_info.py b/test/jit/test_optimize_for_mobile_preserve_debug_info.py
new file mode 100644 (file)
index 0000000..c08f3b5
--- /dev/null
@@ -0,0 +1,261 @@
+import torch
+import torch._C
+import torch.backends.xnnpack
+import torch.nn.functional as F
+from torch.testing._internal.jit_utils import JitTestCase
+
+class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
+    def check_replacement(
+        self,
+        model,
+        replacements,
+        jit_pass,
+    ):
+        """
+        model: Model which optimization is performed on
+        replacements: Dict mapping from nodes' kinds in the optimized model
+            to the kinds of nodes they replaced in the original model
+        jit_pass: Function to perform optimization
+        """
+
+        original_kinds = set(replacements.values())
+        original_source_ranges = {
+            node.kind(): node.sourceRange()
+            for node in model.graph.nodes()
+            if node.kind() in original_kinds
+        }
+
+        jit_pass(model._c)
+
+        for node in model.graph.nodes():
+            if node.kind() in replacements:
+                self.assertEqual(
+                    node.sourceRange(),
+                    original_source_ranges[replacements[node.kind()]],
+                )
+
+    def test_replace_conv1d_with_conv2d(self):
+        class TestConv1d(torch.nn.Module):
+            def __init__(self, weight, bias):
+                super(TestConv1d, self).__init__()
+                self.weight = weight
+                self.bias = bias
+
+            def forward(self, x):
+                return F.conv1d(x, self.weight, self.bias)
+
+        self.check_replacement(
+            model=torch.jit.script(
+                TestConv1d(
+                    weight=torch.rand(3, 3, 3),
+                    bias=torch.rand(3),
+                ),
+            ),
+            replacements={
+                "prim::ListUnpack": "aten::conv1d",
+                "prim::ListConstruct": "aten::conv1d",
+                "aten::unsqueeze": "aten::conv1d",
+                "aten::conv2d": "aten::conv1d",
+                "aten::squeeze": "aten::conv1d",
+            },
+            jit_pass=torch._C._jit_pass_transform_conv1d_to_conv2d,
+        )
+
+    def test_insert_pre_packed_linear_before_inline_and_conv_2d_op(self):
+        class TestPrepackedLinearBeforeInlineAndConv2dOp(torch.nn.Module):
+            def __init__(
+                self,
+                linear_weight,
+                linear_bias,
+                conv2d_weight,
+                conv2d_bias,
+                conv_transpose2d_weight,
+                conv_transpose2d_bias,
+            ):
+                super(
+                    TestPrepackedLinearBeforeInlineAndConv2dOp,
+                    self,
+                ).__init__()
+                self.linear_weight = linear_weight.float()
+                self.linear_bias = linear_bias.float()
+                self.conv2d_weight = conv2d_weight.float()
+                self.conv2d_bias = conv2d_bias.float()
+                self.conv_transpose2d_weight = conv_transpose2d_weight.float()
+                self.conv_transpose2d_bias = conv_transpose2d_bias.float()
+
+            def forward(self, x):
+                linear_res = F.linear(
+                    x.float(),
+                    self.linear_weight,
+                    self.linear_bias,
+                )
+                conv2d_res = F.conv2d(
+                    input=linear_res.unsqueeze(dim=0).float(),
+                    weight=self.conv2d_weight,
+                    bias=self.conv2d_bias,
+                )
+                return F.conv_transpose2d(
+                    input=conv2d_res,
+                    weight=self.conv_transpose2d_weight,
+                    bias=self.conv_transpose2d_bias,
+                )
+
+        minibatch = 1
+        in_channels = 6
+        iH = 4
+        iW = 5
+        out_channels = 6
+        kH = 2
+        kW = 3
+
+        self.check_replacement(
+            model=torch.jit.script(
+                TestPrepackedLinearBeforeInlineAndConv2dOp(
+                    linear_weight=torch.rand(iW, 3),
+                    linear_bias=torch.rand(iW),
+                    conv2d_weight=torch.rand(out_channels, in_channels, kH, kW),
+                    conv2d_bias=torch.rand(out_channels),
+                    conv_transpose2d_weight=torch.rand(
+                        out_channels,
+                        in_channels,
+                        kH,
+                        kW,
+                    ),
+                    conv_transpose2d_bias=torch.rand(out_channels),
+                ),
+            ),
+            replacements={
+                "prepacked::linear_clamp_prepack": "prim::CallFunction",
+                "prepacked::linear_clamp_run": "prim::CallFunction",
+                "prepacked::conv2d_clamp_prepack": "aten::conv2d",
+                "prepacked::conv2d_clamp_run": "aten::conv2d",
+                "prepacked::conv2d_transpose_clamp_prepack":
+                    "aten::conv_transpose2d",
+                "prepacked::conv2d_transpose_clamp_run":
+                    "aten::conv_transpose2d",
+            },
+            jit_pass=torch._C._jit_pass_insert_prepacked_ops,
+        )
+
+    def test_insert_pre_packed_linear_op(self):
+        self.check_replacement(
+            model=torch.jit.trace(torch.nn.Linear(5, 4), torch.rand(3, 2, 5)),
+            replacements={
+                "prepacked::linear_clamp_prepack": "aten::linear",
+                "prepacked::linear_clamp_run": "aten::linear"
+            },
+            jit_pass=torch._C._jit_pass_insert_prepacked_ops,
+        )
+
+    def run_test_fuse_activation_with_pack_ops_linear_conv2d(
+        self,
+        linear_activation,
+        linear_activation_kind,
+        conv2d_activation,
+        conv2d_activation_kind,
+    ):
+        class TestFuseActivationLinearConv2d(torch.nn.Module):
+            def __init__(
+                self,
+                linear_weight,
+                linear_bias,
+                conv2d_weight,
+                conv2d_bias,
+            ):
+                super(TestFuseActivationLinearConv2d, self).__init__()
+                self.linear_weight = linear_weight
+                self.linear_bias = linear_bias
+                self.conv2d_weight = conv2d_weight
+                self.conv2d_bias = conv2d_bias
+
+            def forward(self, x):
+                x = F.linear(
+                    input=x,
+                    weight=self.linear_weight,
+                    bias=self.linear_bias,
+                )
+                x = linear_activation(x)
+                x = F.conv2d(
+                    input=x.unsqueeze(dim=0),
+                    weight=self.conv2d_weight,
+                    bias=self.conv2d_bias,
+                )
+                return conv2d_activation(x)
+
+        linear_in_features = 5
+        linear_out_features = 4
+        conv2d_in_channels = 3
+        conv2d_out_channels = 4
+        conv2d_kernel = 2
+        x_shape = (3, 2, 5)
+
+        model = torch.jit.trace(
+            TestFuseActivationLinearConv2d(
+                linear_weight=torch.nn.Parameter(
+                    data=torch.rand(
+                        linear_out_features,
+                        linear_in_features,
+                    ),
+                    requires_grad=False,
+                ),
+                linear_bias=torch.nn.Parameter(
+                    data=torch.rand(linear_out_features),
+                    requires_grad=False,
+                ),
+                conv2d_weight=torch.rand(
+                    conv2d_out_channels,
+                    conv2d_in_channels,
+                    conv2d_kernel,
+                    conv2d_kernel,
+                ),
+                conv2d_bias=torch.rand(conv2d_out_channels),
+            ),
+            torch.rand(x_shape),
+        )
+
+        torch._C._jit_pass_insert_prepacked_ops(model._c)
+
+        self.check_replacement(
+            model=model,
+            replacements={
+                "prepacked::linear_clamp_prepack":
+                    "prepacked::linear_clamp_prepack",
+                "prepacked::linear_clamp_run": linear_activation_kind,
+                "prepacked::conv2d_clamp_prepack":
+                    "prepacked::conv2d_clamp_prepack",
+                "prepacked::conv2d_clamp_run": conv2d_activation_kind,
+            },
+            jit_pass=torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv,
+        )
+
+    def test_fuse_activation_with_pack_ops_linear_conv2d_1(self):
+        self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
+            linear_activation=F.hardtanh,
+            linear_activation_kind="aten::hardtanh",
+            conv2d_activation=F.hardtanh_,
+            conv2d_activation_kind="aten::hardtanh_"
+        )
+
+    def test_fuse_activation_with_pack_ops_linear_conv2d_2(self):
+        self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
+            linear_activation=F.hardtanh_,
+            linear_activation_kind="aten::hardtanh_",
+            conv2d_activation=F.hardtanh,
+            conv2d_activation_kind="aten::hardtanh"
+        )
+
+    def test_fuse_activation_with_pack_ops_linear_conv2d_3(self):
+        self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
+            linear_activation=F.relu,
+            linear_activation_kind="aten::relu",
+            conv2d_activation=F.relu_,
+            conv2d_activation_kind="aten::relu_"
+        )
+
+    def test_fuse_activation_with_pack_ops_linear_conv2d_4(self):
+        self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
+            linear_activation=F.relu_,
+            linear_activation_kind="aten::relu_",
+            conv2d_activation=F.relu,
+            conv2d_activation_kind="aten::relu"
+        )
index e94ed8d..8d1981d 100644 (file)
@@ -61,6 +61,7 @@ from jit.test_convert_activation import TestFunctionalToInplaceActivation, TestI
 from jit.test_parametrization import TestParametrization  # noqa: F401
 from jit.test_attr import TestGetDefaultAttr  # noqa: F401
 from jit.test_aten_pow import TestAtenPow  # noqa: F401
+from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo  # noqa: F401
 
 # Torch
 from torch import Tensor
index 11210a4..9b2cac6 100644 (file)
@@ -26,8 +26,8 @@ namespace {
 void replaceConv1dWithConv2d(std::shared_ptr<Graph>& graph) {
   std::string conv_1d_pattern = R"(
     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
-        %r = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
-        return (%r) )";
+        %res = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
+        return (%res) )";
 
   std::string conv_2d_pattern = R"(
     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
@@ -47,8 +47,24 @@ void replaceConv1dWithConv2d(std::shared_ptr<Graph>& graph) {
         %output : Tensor = aten::squeeze(%output_2d, %two)
         return (%output) )";
 
+  std::vector<std::pair<std::string, std::string>> value_mappings(
+      {{"zero", "res"},
+       {"one", "res"},
+       {"stride_w", "res"},
+       {"stride_2d", "res"},
+       {"padding_w", "res"},
+       {"padding_2d", "res"},
+       {"dilation_w", "res"},
+       {"dilation_2d", "res"},
+       {"two", "res"},
+       {"input_2d", "res"},
+       {"weight_2d", "res"},
+       {"output_2d", "res"},
+       {"output", "res"}});
+
   SubgraphRewriter rewriter;
-  rewriter.RegisterRewritePattern(conv_1d_pattern, conv_2d_pattern);
+  rewriter.RegisterRewritePattern(
+      conv_1d_pattern, conv_2d_pattern, value_mappings);
   rewriter.runOnGraph(graph);
 }
 
@@ -80,8 +96,8 @@ void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
 
   std::string linear_before_inline = R"(
     graph(%linear, %input, %weight, %bias):
-        %r = prim::CallFunction(%linear, %input, %weight, %bias)
-        return (%r))";
+        %res = prim::CallFunction(%linear, %input, %weight, %bias)
+        return (%res))";
   std::string prepacked_ops_pattern_before_inline = R"(
     graph(%linear, %input, %weight, %bias):
         %output_min_max : None = prim::Constant()
@@ -91,8 +107,8 @@ void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
         return (%res))";
   std::string linear_pattern = R"(
     graph(%input, %weight, %bias):
-        %r = aten::linear(%input, %weight, %bias)
-        return (%r))";
+        %res = aten::linear(%input, %weight, %bias)
+        return (%res))";
   std::string prepacked_ops_pattern = R"(
     graph(%input, %weight, %bias):
         %output_min_max : None = prim::Constant()
@@ -112,13 +128,24 @@ void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
     return false;
   };
 
+  std::vector<std::pair<std::string, std::string>> value_mappings(
+      {{"output_min_max", "res"},
+       {"packed_weight_bias", "res"},
+       {"res", "res"}});
+
   SubgraphRewriter linear_call_fn_rewriter;
   linear_call_fn_rewriter.RegisterRewritePattern(
-      linear_before_inline, prepacked_ops_pattern_before_inline);
+      linear_before_inline,
+      prepacked_ops_pattern_before_inline,
+      value_mappings);
   linear_call_fn_rewriter.runOnGraph(graph, filter);
 
+  value_mappings = {
+      {"output_min_max", "res"}, {"packed_weight_bias", "res"}, {"res", "res"}};
+
   SubgraphRewriter linear_rewriter;
-  linear_rewriter.RegisterRewritePattern(linear_pattern, prepacked_ops_pattern);
+  linear_rewriter.RegisterRewritePattern(
+      linear_pattern, prepacked_ops_pattern, value_mappings);
   linear_rewriter.runOnGraph(graph);
 }
 
@@ -128,8 +155,8 @@ void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
 
   std::string conv_2d_pattern = R"(
     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
-        %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
-        return (%r) )";
+        %res = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
+        return (%res) )";
 
   std::string prepacked_ops_conv2d_pattern = R"(
     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
@@ -137,19 +164,24 @@ void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
         %packed_weight_bias = prepacked::conv2d_clamp_prepack(
             %weight, %bias, %stride, %padding, %dilation, %groups,
             %output_min_max, %output_min_max)
-        %r = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
-        return (%r) )";
+        %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+        return (%res) )";
+
+  std::vector<std::pair<std::string, std::string>> value_mappings(
+      {{"output_min_max", "res"},
+       {"packed_weight_bias", "res"},
+       {"res", "res"}});
 
   SubgraphRewriter rewriter;
   rewriter.RegisterRewritePattern(
-      conv_2d_pattern, prepacked_ops_conv2d_pattern);
+      conv_2d_pattern, prepacked_ops_conv2d_pattern, value_mappings);
   rewriter.runOnGraph(graph);
 
   std::string conv_2d_transpose_pattern = R"(
       graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[],
           %output_padding:int[], %groups:int):
-        %r = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation)
-        return (%r) )";
+        %res = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation)
+        return (%res) )";
 
   std::string prepacked_ops_conv2d_transpose_pattern = R"(
     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int):
@@ -157,12 +189,17 @@ void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
         %packed_weight_bias = prepacked::conv2d_transpose_clamp_prepack(
             %weight, %bias, %stride, %padding, %output_padding, %dilation, %groups,
             %output_min_max, %output_min_max)
-        %r = prepacked::conv2d_transpose_clamp_run(%input, %packed_weight_bias)
-        return (%r) )";
+        %res = prepacked::conv2d_transpose_clamp_run(%input, %packed_weight_bias)
+        return (%res) )";
+
+  value_mappings = {
+      {"output_min_max", "res"}, {"packed_weight_bias", "res"}, {"res", "res"}};
 
   SubgraphRewriter transpose_rewriter;
   transpose_rewriter.RegisterRewritePattern(
-      conv_2d_transpose_pattern, prepacked_ops_conv2d_transpose_pattern);
+      conv_2d_transpose_pattern,
+      prepacked_ops_conv2d_transpose_pattern,
+      value_mappings);
   transpose_rewriter.runOnGraph(graph);
 }
 
@@ -182,8 +219,8 @@ void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
         %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack(
             %weight, %bias, %stride, %padding, %dilation, %groups,
             %output_min, %output_max)
-        %r = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
-        return (%r) )";
+        %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+        return (%res) )";
 
   std::string linear_prepack_run_hardtanh = R"(
     graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
@@ -193,8 +230,13 @@ void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
         %res = aten::hardtanh(%linear_res, %output_min, %output_max)
         return (%res))";
 
+  std::vector<std::pair<std::string, std::string>> value_mappings(
+      {{"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}});
+
   rewriter.RegisterRewritePattern(
-      linear_prepack_run_hardtanh, linear_prepack_run_hardtanh_fused);
+      linear_prepack_run_hardtanh,
+      linear_prepack_run_hardtanh_fused,
+      value_mappings);
 
   std::string conv2d_prepack_run_hardtanh = R"(
     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
@@ -203,11 +245,16 @@ void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
             %weight, %bias, %stride, %padding, %dilation, %groups,
             %dummy_min_max, %dummy_min_max)
         %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
-        %r = aten::hardtanh(%conv2d_res, %output_min, %output_max)
-        return (%r) )";
+        %res = aten::hardtanh(%conv2d_res, %output_min, %output_max)
+        return (%res) )";
+
+  value_mappings = {
+      {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}};
 
   rewriter.RegisterRewritePattern(
-      conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);
+      conv2d_prepack_run_hardtanh,
+      conv2d_prepack_run_hardtanh_fused,
+      value_mappings);
 
   std::string linear_prepack_run_hardtanh_inplace = R"(
     graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
@@ -224,13 +271,24 @@ void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
             %weight, %bias, %stride, %padding, %dilation, %groups,
             %dummy_min_max, %dummy_min_max)
         %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
-        %r = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
-        return (%r) )";
+        %res = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
+        return (%res) )";
+
+  value_mappings = {
+      {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}};
 
   rewriter.RegisterRewritePattern(
-      linear_prepack_run_hardtanh_inplace, linear_prepack_run_hardtanh_fused);
+      linear_prepack_run_hardtanh_inplace,
+      linear_prepack_run_hardtanh_fused,
+      value_mappings);
+
+  value_mappings = {
+      {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}};
+
   rewriter.RegisterRewritePattern(
-      conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);
+      conv2d_prepack_run_hardtanh_inplace,
+      conv2d_prepack_run_hardtanh_fused,
+      value_mappings);
 
   rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
 }
@@ -255,8 +313,8 @@ void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
         %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack(
             %weight, %bias, %stride, %padding, %dilation, %groups,
             %output_min, %output_max)
-        %r = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
-        return (%r) )";
+        %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+        return (%res) )";
 
   std::string linear_prepack_run_relu = R"(
     graph(%input, %weight, %bias, %dummy_min_max):
@@ -266,8 +324,14 @@ void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
         %res = aten::relu(%linear_res)
         return (%res))";
 
+  std::vector<std::pair<std::string, std::string>> value_mappings(
+      {{"output_min", "packed_weight_bias"},
+       {"output_max", "packed_weight_bias"},
+       {"packed_weight_bias", "packed_weight_bias"},
+       {"res", "res"}});
+
   rewriter.RegisterRewritePattern(
-      linear_prepack_run_relu, linear_prepack_run_relu_fused);
+      linear_prepack_run_relu, linear_prepack_run_relu_fused, value_mappings);
 
   std::string conv2d_prepack_run_relu = R"(
     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
@@ -276,11 +340,17 @@ void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
             %weight, %bias, %stride, %padding, %dilation, %groups,
             %dummy_min_max, %dummy_min_max)
         %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
-        %r = aten::relu(%conv2d_res)
-        return (%r) )";
+        %res = aten::relu(%conv2d_res)
+        return (%res) )";
+
+  value_mappings = {
+      {"output_min", "packed_weight_bias"},
+      {"output_max", "packed_weight_bias"},
+      {"packed_weight_bias", "packed_weight_bias"},
+      {"res", "res"}};
 
   rewriter.RegisterRewritePattern(
-      conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);
+      conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused, value_mappings);
 
   std::string linear_prepack_run_relu_inplace = R"(
     graph(%input, %weight, %bias, %dummy_min_max):
@@ -297,13 +367,30 @@ void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
             %weight, %bias, %stride, %padding, %dilation, %groups,
             %dummy_min_max, %dummy_min_max)
         %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
-        %r = aten::relu_(%conv2d_res)
-        return (%r) )";
+        %res = aten::relu_(%conv2d_res)
+        return (%res) )";
+
+  value_mappings = {
+      {"output_min", "packed_weight_bias"},
+      {"output_max", "packed_weight_bias"},
+      {"packed_weight_bias", "packed_weight_bias"},
+      {"res", "res"}};
 
   rewriter.RegisterRewritePattern(
-      linear_prepack_run_relu_inplace, linear_prepack_run_relu_fused);
+      linear_prepack_run_relu_inplace,
+      linear_prepack_run_relu_fused,
+      value_mappings);
+
+  value_mappings = {
+      {"output_min", "packed_weight_bias"},
+      {"output_max", "packed_weight_bias"},
+      {"packed_weight_bias", "packed_weight_bias"},
+      {"res", "res"}};
+
   rewriter.RegisterRewritePattern(
-      conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);
+      conv2d_prepack_run_relu_inplace,
+      conv2d_prepack_run_relu_fused,
+      value_mappings);
   rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
 }