From 1256dcd50967b18c2ca335662558e77aeefe4f13 Mon Sep 17 00:00:00 2001
From: Karen Zhou
Date: Tue, 24 Aug 2021 10:17:28 -0700
Subject: [PATCH] [pruner] modify base pruner to prune bias by default (#63202)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63202

By default, the pruner will also prune biases, such that the whole output channel is removed.

The user can manually set `also_prune_bias` to False when creating the pruner (before calling `prepare`) if they don't want the bias to be pruned.
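
As a rough usage sketch (not part of the diff below; `MyPruner`, its `update_mask` stub, and the exact `prepare` arguments are placeholders based on the surrounding BaseSparsifier API and may differ in detail):

    import torch.nn as nn
    from torch.ao.sparsity import BasePruner

    class MyPruner(BasePruner):
        def update_mask(self, layer, **kwargs):
            pass  # pruning criterion would go here

    model = nn.Sequential(nn.Linear(8, 4))
    pruner = MyPruner(defaults={}, also_prune_bias=False)  # keep biases intact
    pruner.prepare(model, config=None)
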
""" - def __init__(self, defaults): + def __init__(self, defaults, also_prune_bias=True): super().__init__(defaults) - - def bias_hook(self, module, input, output): - if getattr(module, '_bias', None) is not None: - idx = [1] * len(output.shape) - idx[1] = output.shape[1] - bias = module._bias.reshape(idx) - output += bias - return output + self.prune_bias = also_prune_bias def _prepare(self, use_path=False, *args, **kwargs): r"""Adds mask parametrization to the layer weight @@ -78,7 +68,7 @@ class BasePruner(BaseSparsifier): if module.bias is not None: module.register_parameter('_bias', nn.Parameter(module.bias.detach())) module.bias = None - self.bias_handles.append(module.register_forward_hook(self.bias_hook)) + self.bias_handles.append(module.register_forward_hook(BiasHook(module.parametrizations.weight[0], self.prune_bias))) def squash_mask(self, use_path=False, *args, **kwargs): for config in self.module_groups: diff --git a/torch/ao/sparsity/experimental/pruner/parametrization.py b/torch/ao/sparsity/experimental/pruner/parametrization.py index d4bebb2..696b16e 100644 --- a/torch/ao/sparsity/experimental/pruner/parametrization.py +++ b/torch/ao/sparsity/experimental/pruner/parametrization.py @@ -36,3 +36,25 @@ class ActivationReconstruction: reconstructed_tensor = torch.zeros(sizes) reconstructed_tensor[indices] = output return reconstructed_tensor + + +class BiasHook: + def __init__(self, parametrization, prune_bias): + self.param = parametrization + self.prune_bias = prune_bias + + def __call__(self, module, input, output): + pruned_outputs = self.param.pruned_outputs + + if getattr(module, '_bias', None) is not None: + bias = module._bias.data + if self.prune_bias: + bias[list(pruned_outputs)] = 0 + + # reshape bias to broadcast over output dimensions + idx = [1] * len(output.shape) + idx[1] = -1 + bias = bias.reshape(idx) + + output += bias + return output -- 2.7.4