# Assume that this is the 1st/only parametrization
assert type(module.parametrizations.weight[0]) == PruningParametrization
- def _check_pruner_converted(self, model, pruner, device):
+ def _check_pruner_mask_squashed(self, model, pruner, device):
for g in pruner.module_groups:
module = g['module']
assert module.weight.device == device
self.assertRaisesRegex(TypeError, 'with abstract methods update_mask',
BasePruner)
model = model.to(device)
- pruner = SimplePruner(model, None, None)
+ pruner = SimplePruner(None)
+ pruner.prepare(model, None)
for g in pruner.module_groups:
module = g['module']
assert module.weight.device == device
assert len(pruner.module_groups) == 2
pruner.step()
# Can instantiate the model with configs
- pruner = SimplePruner(model, [model.linear], {'test': 3})
+ pruner = SimplePruner({'test': 3})
+ pruner.prepare(model, [model.linear])
assert len(pruner.module_groups) == 1
- assert pruner.module_groups[0]['path'] == 'linear'
+ assert pruner.module_groups[0]['fqn'] == 'linear'
assert 'test' in pruner.module_groups[0]
assert pruner.module_groups[0]['test'] == 3
def _test_prepare_linear_on_device(self, model, device):
model = model.to(device)
x = torch.ones(128, 16)
- pruner = SimplePruner(model, None, None)
- pruner.prepare()
+ pruner = SimplePruner(None)
+ pruner.prepare(model, None)
self._check_pruner_prepared(model, pruner, device)
assert model(x).shape == (128, 16)
def _test_prepare_conv2d_on_device(self, model, device):
model = model.to(device)
x = torch.ones((1, 1, 28, 28))
- pruner = SimplePruner(model, None, None)
- pruner.prepare()
+ pruner = SimplePruner(None)
+ pruner.prepare(model, None)
self._check_pruner_prepared(model, pruner, device)
assert model(x).shape == (1, 64, 24, 24)
for model in models:
self._test_prepare_conv2d_on_device(model, torch.device(device))
- def _test_convert_linear_on_device(self, model, device):
+ def _test_squash_mask_linear_on_device(self, model, device):
model = model.to(device)
x = torch.ones(128, 16)
- pruner = SimplePruner(model, None, None)
- pruner.prepare()
- pruner.convert()
- self._check_pruner_converted(model, pruner, device)
+ pruner = SimplePruner(None)
+ pruner.prepare(model, None)
+ pruner.squash_mask()
+ self._check_pruner_mask_squashed(model, pruner, device)
assert model(x).shape == (128, 16)
- def test_convert_linear(self):
+ def test_squash_mask_linear(self):
models = [Linear(), LinearB()] # without and with bias
for device in DEVICES:
for model in models:
- self._test_convert_linear_on_device(model, torch.device(device))
+ self._test_squash_mask_linear_on_device(model, torch.device(device))
- def _test_convert_conv2d_on_device(self, model, device):
+ def _test_squash_mask_conv2d_on_device(self, model, device):
model = model.to(device)
x = torch.ones((1, 1, 28, 28))
- pruner = SimplePruner(model, None, None)
- pruner.prepare()
- pruner.convert()
- self._check_pruner_converted(model, pruner, device)
+ pruner = SimplePruner(None)
+ pruner.prepare(model, None)
+ pruner.squash_mask()
+ self._check_pruner_mask_squashed(model, pruner, device)
assert model(x).shape == (1, 64, 24, 24)
- def test_convert_conv2d(self):
+ def test_squash_mask_conv2d(self):
models = [Conv2dA(), Conv2dB(), Conv2dC()]
for device in DEVICES:
for model in models:
- self._test_convert_conv2d_on_device(model, torch.device(device))
+ self._test_squash_mask_conv2d_on_device(model, torch.device(device))
def _test_step_linear_on_device(self, model, is_basic, device):
model = model.to(device)
if is_basic:
x = torch.ones(16, 16)
- pruner = SimplePruner(model, None, None)
- pruner.prepare()
- pruner.enable_mask_update = True
+ pruner = SimplePruner(None)
+ pruner.prepare(model, None)
self._check_pruner_valid_before_step(model, pruner, device)
pruner.step()
self._check_pruner_valid_after_step(model, pruner, {1}, device)
else:
x = torch.ones(7, 7)
- pruner = MultiplePruner(model, None, None)
- pruner.prepare()
- pruner.enable_mask_update = True
+ pruner = MultiplePruner(None)
+ pruner.prepare(model, None)
self._check_pruner_valid_before_step(model, pruner, device)
pruner.step()
self._check_pruner_valid_after_step(model, pruner, {1, 2}, device)
def _test_step_conv2d_on_device(self, model, device):
model = model.to(device)
x = torch.ones((1, 1, 28, 28))
- pruner = SimplePruner(model, None, None)
- pruner.prepare()
- pruner.enable_mask_update = True
+ pruner = SimplePruner(None)
+ pruner.prepare(model, None)
self._check_pruner_valid_before_step(model, pruner, device)
pruner.step()
self._check_pruner_valid_after_step(model, pruner, {1}, device)
# Parametrizations
from .sparsifier.utils import FakeSparsity
+from .sparsifier.utils import module_to_fqn
+from .sparsifier.utils import fqn_to_module
# === Experimental ===
import abc
-import copy
import torch
from torch import nn
from .parametrization import PruningParametrization, ActivationReconstruction
+from torch.ao.sparsity import BaseSparsifier, fqn_to_module
+
SUPPORTED_MODULES = {
nn.Linear,
nn.Conv2d
}
-def _module_to_path(model, layer, prefix=''):
- for name, child in model.named_children():
- new_name = prefix + '.' + name
- if child is layer:
- return new_name
- child_path = _module_to_path(child, layer, prefix=new_name)
- if child_path is not None:
- return child_path
- return None
-
-def _path_to_module(model, path):
- path = path.split('.')
- for name in path:
- model = getattr(model, name, None)
- if model is None:
- return None
- return model
-
-
-class BasePruner(abc.ABC):
+
+class BasePruner(BaseSparsifier):
r"""Base class for all pruners.
Abstract methods that need to be implemented:
- update_mask: computes a new mask for all modules in `module_groups`;
  called by `step()` to update the pruning parametrizations.
"""
- def __init__(self, model, config, defaults):
- super().__init__()
- self.config = config
- self.defaults = defaults
- if self.defaults is None:
- self.defaults = dict()
-
- self.module_groups = []
- self.enable_mask_update = False
- self.activation_handles = []
- self.bias_handles = []
-
- self.model = model
- # If no config -- try getting all the supported layers
- if self.config is None:
- # Add all models to the config
- self.config = []
- stack = [model]
- while stack:
- module = stack.pop()
- for name, child in module.named_children():
- if type(child) in SUPPORTED_MODULES:
- self.config.append(child)
- else:
- stack.append(child)
-
- for module_config in self.config:
- if isinstance(module_config, nn.Module):
- module_config = {'module': module_config}
- local_args = copy.deepcopy(self.defaults)
- local_args.update(module_config)
- module = local_args['module']
- module_path = _module_to_path(self.model, module)
- if module_path and module_path[0] == '.':
- module_path = module_path[1:]
- local_args['path'] = module_path
- self.module_groups.append(local_args)
-
- def __getstate__(self):
- return {
- 'defaults': self.defaults,
- 'module_groups': self.module_groups,
- }
-
- def __setstate__(self, state):
- self.__dict__.update(state)
-
- def __repr__(self):
- format_string = self.__class__.__name__ + ' ('
- for i, sparse_args in enumerate(self.module_groups):
- module = sparse_args['module']
- format_string += '\n'
- format_string += f'\tModule Group {i}\n'
- format_string += f'\t module: {module}\n'
- for key in sorted(sparse_args.keys()):
- if key == 'module':
- continue
- format_string += f'\t {key}: {sparse_args[key]}\n'
- format_string += ')'
- return format_string
+ def __init__(self, defaults):
+ super().__init__(defaults)
def bias_hook(self, module, input, output):
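# Re-applies the layer's original bias to the output. `_prepare` clears
# `module.bias` (see below) and the original value is assumed to be stashed
# on the module as `_bias`; this hook adds it back after the masked weight
# has been applied.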
if getattr(module, '_bias', None) is not None:
output += module._bias
return output
- def prepare(self, use_path=False, *args, **kwargs):
+ def _prepare(self, use_path=False, *args, **kwargs):
r"""Adds mask parametrization to the layer weight
"""
+ self.activation_handles = [] # store removable hook handles
+ self.bias_handles = []
+
for config in self.module_groups:
if use_path:
- module = _path_to_module(self.model, config['path'])
+ module = fqn_to_module(self.model, config['fqn'])
else:
module = config['module']
module.bias = None
self.bias_handles.append(module.register_forward_hook(self.bias_hook))
- def convert(self, use_path=False, *args, **kwargs):
+ def squash_mask(self, use_path=False, *args, **kwargs):
for config in self.module_groups:
if use_path:
- module = _path_to_module(self.model, config['path'])
+ module = fqn_to_module(self.model, config['fqn'])
else:
module = config['module']
parametrize.remove_parametrizations(module, 'weight',
                                    leave_parametrized=True)
del module._buffers['mask']
delattr(module, 'mask')
- def step(self, use_path=True):
- if not self.enable_mask_update:
- return
- with torch.no_grad():
- for config in self.module_groups:
- if use_path:
- module = _path_to_module(self.model, config['path'])
- else:
- module = config['module']
- self.update_mask(module, **config)
-
@abc.abstractmethod
def update_mask(self, layer, **kwargs):
pass
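
# --- Illustrative sketch (editor's note, not part of the patch) ---
# A concrete pruner only needs to implement `update_mask`. The class name
# `ExamplePruner` and the `pruned_outputs` set below are assumptions inferred
# from the tests above (which verify pruned index sets such as {1} and
# {1, 2}); they are not confirmed API.
class ExamplePruner(BasePruner):
    def update_mask(self, layer, **kwargs):
        # Mark output index 1 of this layer as pruned via its parametrization.
        layer.parametrizations.weight[0].pruned_outputs.add(1)

# Intended flow, mirroring the tests: construct with defaults only, then
# attach to a model, update the masks, and finally fold them into the weights.
# pruner = ExamplePruner(None)
# pruner.prepare(model, None)
# pruner.step()
# pruner.squash_mask()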
from torch import nn
from torch.nn.utils import parametrize
-from .utils import FakeSparsity
+from .utils import FakeSparsity, module_to_fqn, fqn_to_module
SUPPORTED_MODULES = {
nn.Linear
}
-def _module_to_fqn(model, layer, prefix=''):
- for name, child in model.named_children():
- new_name = prefix + '.' + name
- if child is layer:
- return new_name
- child_path = _module_to_fqn(child, layer, prefix=new_name)
- if child_path is not None:
- return child_path
- return None
-
-def _fqn_to_module(model, path):
- path = path.split('.')
- for name in path:
- model = getattr(model, name, None)
- if model is None:
- return None
- return model
-
class BaseSparsifier(abc.ABC):
r"""Base class for all sparsifiers.
module_groups = copy.deepcopy(state_dict['module_groups'])
states = state_dict['state']
for fqn, s in states.items():
- layer = _fqn_to_module(self.model, fqn)
+ layer = fqn_to_module(self.model, fqn)
if strict and layer is None:
raise RuntimeError(f'Error loading {fqn} into the model')
local_args = copy.deepcopy(self.defaults)
local_args.update(module_config)
module = local_args['module']
- module_fqn = _module_to_fqn(model, module)
+ module_fqn = module_to_fqn(model, module)
if module_fqn and module_fqn[0] == '.':
module_fqn = module_fqn[1:]
local_args['fqn'] = module_fqn
from torch import nn
+def module_to_fqn(model, layer, prefix=''):
+ for name, child in model.named_children():
+ new_name = prefix + '.' + name
+ if child is layer:
+ return new_name
+ child_path = module_to_fqn(child, layer, prefix=new_name)
+ if child_path is not None:
+ return child_path
+ return None
+
+def fqn_to_module(model, path):
+ path = path.split('.')
+ for name in path:
+ model = getattr(model, name, None)
+ if model is None:
+ return None
+ return model
+
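# Example usage (editor's sketch, not part of the patch): `module_to_fqn`
# returns the dotted path of a submodule within `model`, and `fqn_to_module`
# resolves such a path back to the submodule. Note that the returned path
# carries a leading '.', which callers strip (see the `module_fqn[0] == '.'`
# check in BaseSparsifier).
#
#   model = nn.Sequential(nn.Linear(16, 16))
#   module_to_fqn(model, model[0])   # -> '.0'
#   fqn_to_module(model, '0')        # -> the nn.Linear submodule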
# Parametrizations
class FakeSparsity(nn.Module):
r"""Parametrization for the weights. Should be attached to the 'weight' or