[pruner] modify base pruner to prune bias by default (#63202)
authorKaren Zhou <kazhou@fb.com>
Tue, 24 Aug 2021 17:17:28 +0000 (10:17 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 24 Aug 2021 17:25:45 +0000 (10:25 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63202

By default, the prune will also prune biases, such that the whole output channel is removed. The user can manually set `also_prune_bias` to False when calling `prepare` if they don't want the bias to be pruned.
ghstack-source-id: 136466671

Test Plan:
`buck test mode/dev-nosan //caffe2/test:ao -- TestBasePruner`

https://pxl.cl/1MV32

modify `fusion_tests` according to API change
`buck test mode/opt //scripts/kazhou:fusion_tests`

https://pxl.cl/1NbKz

Reviewed By: z-a-f

Differential Revision: D30294494

fbshipit-source-id: c84655648bee0035559195ca855b98fb7edaa134

torch/ao/sparsity/__init__.py
torch/ao/sparsity/experimental/pruner/base_pruner.py
torch/ao/sparsity/experimental/pruner/parametrization.py

index 55b8d70..06854a4 100644 (file)
@@ -20,6 +20,7 @@ from .sparsifier.utils import fqn_to_module
 # Parametrizations
 from .experimental.pruner.parametrization import PruningParametrization
 from .experimental.pruner.parametrization import ActivationReconstruction
+from .experimental.pruner.parametrization import BiasHook
 
 # Pruner
 from .experimental.pruner.base_pruner import BasePruner
index d89b3cc..a8a7b69 100644 (file)
@@ -7,7 +7,7 @@ from torch.nn.utils import parametrize
 
 from torch.nn.modules.container import ModuleDict, ModuleList
 
-from .parametrization import PruningParametrization, ActivationReconstruction
+from .parametrization import PruningParametrization, ActivationReconstruction, BiasHook
 
 from torch.ao.sparsity import BaseSparsifier, fqn_to_module
 
@@ -26,26 +26,16 @@ class BasePruner(BaseSparsifier):
         `module_groups`.
 
     Args:
-        - model [nn.Module]: model to configure. The model itself is not saved
-            but used for the state_dict saving / loading.
-        - config [list]: configuration elements could either be instances of
-            nn.Module or dict maps. The dicts must have a key 'module' with the
-            value being an instance of a nn.Module.
         - defaults [dict]: default configurations will be attached to the
             configuration. Only the keys that don't exist in the `config` will
             be updated.
+        - also_prune_bias [bool]: whether to prune bias in addition to weights (to prune full output channel)
+            or not; default=True.
 
     """
-    def __init__(self, defaults):
+    def __init__(self, defaults, also_prune_bias=True):
         super().__init__(defaults)
-
-    def bias_hook(self, module, input, output):
-        if getattr(module, '_bias', None) is not None:
-            idx = [1] * len(output.shape)
-            idx[1] = output.shape[1]
-            bias = module._bias.reshape(idx)
-            output += bias
-        return output
+        self.prune_bias = also_prune_bias
 
     def _prepare(self, use_path=False, *args, **kwargs):
         r"""Adds mask parametrization to the layer weight
@@ -78,7 +68,7 @@ class BasePruner(BaseSparsifier):
             if module.bias is not None:
                 module.register_parameter('_bias', nn.Parameter(module.bias.detach()))
                 module.bias = None
-            self.bias_handles.append(module.register_forward_hook(self.bias_hook))
+            self.bias_handles.append(module.register_forward_hook(BiasHook(module.parametrizations.weight[0], self.prune_bias)))
 
     def squash_mask(self, use_path=False, *args, **kwargs):
         for config in self.module_groups:
index d4bebb2..696b16e 100644 (file)
@@ -36,3 +36,25 @@ class ActivationReconstruction:
         reconstructed_tensor = torch.zeros(sizes)
         reconstructed_tensor[indices] = output
         return reconstructed_tensor
+
+
+class BiasHook:
+    def __init__(self, parametrization, prune_bias):
+        self.param = parametrization
+        self.prune_bias = prune_bias
+
+    def __call__(self, module, input, output):
+        pruned_outputs = self.param.pruned_outputs
+
+        if getattr(module, '_bias', None) is not None:
+            bias = module._bias.data
+            if self.prune_bias:
+                bias[list(pruned_outputs)] = 0
+
+            # reshape bias to broadcast over output dimensions
+            idx = [1] * len(output.shape)
+            idx[1] = -1
+            bias = bias.reshape(idx)
+
+            output += bias
+        return output