from torch.nn.modules.container import ModuleDict, ModuleList
-from .parametrization import PruningParametrization, ActivationReconstruction
+from .parametrization import PruningParametrization, ActivationReconstruction, BiasHook
from torch.ao.sparsity import BaseSparsifier, fqn_to_module
`module_groups`.
Args:
- - model [nn.Module]: model to configure. The model itself is not saved
- but used for the state_dict saving / loading.
- - config [list]: configuration elements could either be instances of
- nn.Module or dict maps. The dicts must have a key 'module' with the
- value being an instance of a nn.Module.
- defaults [dict]: default configurations will be attached to the
configuration. Only the keys that don't exist in the `config` will
be updated.
+ - also_prune_bias [bool]: whether to prune the bias together with the weights, so that the
+ whole output channel is pruned; default=True.
"""
- def __init__(self, defaults):
+ def __init__(self, defaults, also_prune_bias=True):
super().__init__(defaults)
-
- def bias_hook(self, module, input, output):
- if getattr(module, '_bias', None) is not None:
- idx = [1] * len(output.shape)
- idx[1] = output.shape[1]
- bias = module._bias.reshape(idx)
- output += bias
- return output
+ self.prune_bias = also_prune_bias
def _prepare(self, use_path=False, *args, **kwargs):
r"""Adds mask parametrization to the layer weight
if module.bias is not None:
module.register_parameter('_bias', nn.Parameter(module.bias.detach()))
module.bias = None
- self.bias_handles.append(module.register_forward_hook(self.bias_hook))
+ self.bias_handles.append(module.register_forward_hook(BiasHook(module.parametrizations.weight[0], self.prune_bias)))
def squash_mask(self, use_path=False, *args, **kwargs):
for config in self.module_groups:
reconstructed_tensor = torch.zeros(sizes)
reconstructed_tensor[indices] = output
return reconstructed_tensor
+
+
class BiasHook:
    """Forward hook that adds a module's stored ``_bias`` back onto its output.

    The hook reads ``module._bias`` (when present), optionally treats the
    entries listed in ``parametrization.pruned_outputs`` as zero, and adds the
    result to the module output, broadcasting along dimension 1.

    Args:
        parametrization: object exposing a ``pruned_outputs`` collection of
            indices into the bias.
        prune_bias (bool): when true, the bias entries at the pruned indices
            contribute zero to the output.
    """

    def __init__(self, parametrization, prune_bias):
        self.param = parametrization
        self.prune_bias = prune_bias

    def __call__(self, module, input, output):
        """Add the (optionally pruned) bias to ``output`` in place and return it.

        ``output`` is returned unchanged when ``module`` has no ``_bias``
        attribute or when it is ``None``.
        """
        stored_bias = getattr(module, '_bias', None)
        if stored_bias is None:
            return output

        bias = stored_bias.data
        if self.prune_bias:
            pruned = list(self.param.pruned_outputs)
            if pruned:
                # Zero a copy rather than the stored tensor: writing through
                # `.data` would permanently erase the bias values, so they
                # could not be recovered if the pruned set later changes.
                bias = bias.clone()
                bias[pruned] = 0

        # reshape bias to broadcast over output dimensions
        # NOTE(review): assumes the bias axis of `output` is dim 1 and that
        # `output` has at least two dimensions - confirm against callers.
        broadcast_shape = [1] * len(output.shape)
        broadcast_shape[1] = -1

        output += bias.reshape(broadcast_shape)
        return output