From 16ba20507a7a8fcb62d88c719eceab578e09e210 Mon Sep 17 00:00:00 2001
From: Karen Zhou
Date: Tue, 24 Aug 2021 10:17:28 -0700
Subject: [PATCH] [pruner] amend base pruner API to match base sparsifier
 (#63178)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63178

Update the base pruner API to match the base sparsifier API as defined in
D28970960 / PR58955.

Changes include:
- `enable_mask_update = True` in `__init__`
- `prepare` takes the model and config, instead of the constructor
- convert functionality renamed to `squash_mask`; calling `convert` now
  raises an Error
- `activation_handles` and `bias_handles` initialized in `_prepare` instead
  of the constructor

ghstack-source-id: 136467595

Test Plan:
Function names updated according to the changes:

`buck test mode/dev-nosan //caffe2/test:ao -- TestBasePruner`

https://pxl.cl/1MTgH

TODO: `fbcode/scripts/kazhou/fusion_tests.py` will need to be modified to use
the new API.

Reviewed By: z-a-f

Differential Revision: D30287179

fbshipit-source-id: d4727bea1873b500f2d4bb784db26d532bf26cce
---
 test/ao/sparsity/test_pruner.py                    |  61 ++++++------
 torch/ao/sparsity/__init__.py                      |   2 +
 .../ao/sparsity/experimental/pruner/base_pruner.py | 109 +++------------------
 torch/ao/sparsity/sparsifier/base_sparsifier.py    |  24 +----
 torch/ao/sparsity/sparsifier/utils.py              |  18 ++++
 5 files changed, 66 insertions(+), 148 deletions(-)

diff --git a/test/ao/sparsity/test_pruner.py b/test/ao/sparsity/test_pruner.py
index 8f5f6dd..c358df6 100644
--- a/test/ao/sparsity/test_pruner.py
+++ b/test/ao/sparsity/test_pruner.py
@@ -161,7 +161,7 @@ class TestBasePruner(TestCase):
         # Assume that this is the 1st/only parametrization
         assert type(module.parametrizations.weight[0]) == PruningParametrization
 
-    def _check_pruner_converted(self, model, pruner, device):
+    def _check_pruner_mask_squashed(self, model, pruner, device):
         for g in pruner.module_groups:
             module = g['module']
             assert module.weight.device == device
@@ -184,16 +184,18 @@
         self.assertRaisesRegex(TypeError, 'with abstract methods update_mask',
                                BasePruner)
         model = model.to(device)
-        pruner = SimplePruner(model, None, None)
+        pruner = SimplePruner(None)
+        pruner.prepare(model, None)
         for g in pruner.module_groups:
             module = g['module']
             assert module.weight.device == device
         assert len(pruner.module_groups) == 2
         pruner.step()
         # Can instantiate the model with configs
-        pruner = SimplePruner(model, [model.linear], {'test': 3})
+        pruner = SimplePruner({'test': 3})
+        pruner.prepare(model, [model.linear])
         assert len(pruner.module_groups) == 1
-        assert pruner.module_groups[0]['path'] == 'linear'
+        assert pruner.module_groups[0]['fqn'] == 'linear'
         assert 'test' in pruner.module_groups[0]
         assert pruner.module_groups[0]['test'] == 3
 
@@ -205,8 +207,8 @@
     def _test_prepare_linear_on_device(self, model, device):
         model = model.to(device)
         x = torch.ones(128, 16)
-        pruner = SimplePruner(model, None, None)
-        pruner.prepare()
+        pruner = SimplePruner(None)
+        pruner.prepare(model, None)
         self._check_pruner_prepared(model, pruner, device)
         assert model(x).shape == (128, 16)
 
@@ -219,8 +221,8 @@
    def _test_prepare_conv2d_on_device(self, model, device):
         model = model.to(device)
         x = torch.ones((1, 1, 28, 28))
-        pruner = SimplePruner(model, None, None)
-        pruner.prepare()
+        pruner = SimplePruner(None)
+        pruner.prepare(model, None)
         self._check_pruner_prepared(model, pruner, device)
         assert model(x).shape == (1, 64, 24, 24)
 
@@ -230,51 +232,49 @@
             for model in models:
                 self._test_prepare_conv2d_on_device(model, torch.device(device))
 
-    def _test_convert_linear_on_device(self, model, device):
+    def _test_squash_mask_linear_on_device(self, model, device):
         model = model.to(device)
         x = torch.ones(128, 16)
-        pruner = SimplePruner(model, None, None)
-        pruner.prepare()
-        pruner.convert()
-        self._check_pruner_converted(model, pruner, device)
+        pruner = SimplePruner(None)
+        pruner.prepare(model, None)
+        pruner.squash_mask()
+        self._check_pruner_mask_squashed(model, pruner, device)
         assert model(x).shape == (128, 16)
 
-    def test_convert_linear(self):
+    def test_squash_mask_linear(self):
         models = [Linear(), LinearB()]  # without and with bias
         for device in DEVICES:
             for model in models:
-                self._test_convert_linear_on_device(model, torch.device(device))
+                self._test_squash_mask_linear_on_device(model, torch.device(device))
 
-    def _test_convert_conv2d_on_device(self, model, device):
+    def _test_squash_mask_conv2d_on_device(self, model, device):
         model = model.to(device)
         x = torch.ones((1, 1, 28, 28))
-        pruner = SimplePruner(model, None, None)
-        pruner.prepare()
-        pruner.convert()
-        self._check_pruner_converted(model, pruner, device)
+        pruner = SimplePruner(None)
+        pruner.prepare(model, None)
+        pruner.squash_mask()
+        self._check_pruner_mask_squashed(model, pruner, device)
         assert model(x).shape == (1, 64, 24, 24)
 
-    def test_convert_conv2d(self):
+    def test_squash_mask_conv2d(self):
         models = [Conv2dA(), Conv2dB(), Conv2dC()]
         for device in DEVICES:
             for model in models:
-                self._test_convert_conv2d_on_device(model, torch.device(device))
+                self._test_squash_mask_conv2d_on_device(model, torch.device(device))
 
     def _test_step_linear_on_device(self, model, is_basic, device):
         model = model.to(device)
         if is_basic:
             x = torch.ones(16, 16)
-            pruner = SimplePruner(model, None, None)
-            pruner.prepare()
-            pruner.enable_mask_update = True
+            pruner = SimplePruner(None)
+            pruner.prepare(model, None)
             self._check_pruner_valid_before_step(model, pruner, device)
             pruner.step()
             self._check_pruner_valid_after_step(model, pruner, {1}, device)
         else:
             x = torch.ones(7, 7)
-            pruner = MultiplePruner(model, None, None)
-            pruner.prepare()
-            pruner.enable_mask_update = True
+            pruner = MultiplePruner(None)
+            pruner.prepare(model, None)
             self._check_pruner_valid_before_step(model, pruner, device)
             pruner.step()
             self._check_pruner_valid_after_step(model, pruner, {1, 2}, device)
@@ -291,9 +291,8 @@
     def _test_step_conv2d_on_device(self, model, device):
         model = model.to(device)
         x = torch.ones((1, 1, 28, 28))
-        pruner = SimplePruner(model, None, None)
-        pruner.prepare()
-        pruner.enable_mask_update = True
+        pruner = SimplePruner(None)
+        pruner.prepare(model, None)
         self._check_pruner_valid_before_step(model, pruner, device)
         pruner.step()
         self._check_pruner_valid_after_step(model, pruner, {1}, device)
diff --git a/torch/ao/sparsity/__init__.py b/torch/ao/sparsity/__init__.py
index 9ba05f2..55b8d70 100644
--- a/torch/ao/sparsity/__init__.py
+++ b/torch/ao/sparsity/__init__.py
@@ -12,6 +12,8 @@
 from .scheduler.lambda_scheduler import LambdaSL
 
 # Parametrizations
 from .sparsifier.utils import FakeSparsity
+from .sparsifier.utils import module_to_fqn
+from .sparsifier.utils import fqn_to_module
 
 # === Experimental ===
diff --git a/torch/ao/sparsity/experimental/pruner/base_pruner.py b/torch/ao/sparsity/experimental/pruner/base_pruner.py
index 92e1945..d89b3cc 100644
--- a/torch/ao/sparsity/experimental/pruner/base_pruner.py
+++ b/torch/ao/sparsity/experimental/pruner/base_pruner.py
@@ -1,6 +1,5 @@
 import abc
-import copy
 
 import torch
 from torch import nn
@@ -10,31 +9,15 @@ from torch.nn.modules.container import ModuleDict, ModuleList
 
 from .parametrization import PruningParametrization, ActivationReconstruction
 
+from torch.ao.sparsity import BaseSparsifier, fqn_to_module
+
 SUPPORTED_MODULES = {
     nn.Linear,
     nn.Conv2d
 }
 
-def _module_to_path(model, layer, prefix=''):
-    for name, child in model.named_children():
-        new_name = prefix + '.' + name
-        if child is layer:
-            return new_name
-        child_path = _module_to_path(child, layer, prefix=new_name)
-        if child_path is not None:
-            return child_path
-    return None
-
-def _path_to_module(model, path):
-    path = path.split('.')
-    for name in path:
-        model = getattr(model, name, None)
-        if model is None:
-            return None
-    return model
-
-
-class BasePruner(abc.ABC):
+
+class BasePruner(BaseSparsifier):
     r"""Base class for all pruners.
 
     Abstract methods that need to be implemented:
@@ -53,66 +36,8 @@ class BasePruner(abc.ABC):
         be updated.
     """
 
-    def __init__(self, model, config, defaults):
-        super().__init__()
-        self.config = config
-        self.defaults = defaults
-        if self.defaults is None:
-            self.defaults = dict()
-
-        self.module_groups = []
-        self.enable_mask_update = False
-        self.activation_handles = []
-        self.bias_handles = []
-
-        self.model = model
-        # If no config -- try getting all the supported layers
-        if self.config is None:
-            # Add all models to the config
-            self.config = []
-            stack = [model]
-            while stack:
-                module = stack.pop()
-                for name, child in module.named_children():
-                    if type(child) in SUPPORTED_MODULES:
-                        self.config.append(child)
-                    else:
-                        stack.append(child)
-
-        for module_config in self.config:
-            if isinstance(module_config, nn.Module):
-                module_config = {'module': module_config}
-            local_args = copy.deepcopy(self.defaults)
-            local_args.update(module_config)
-            module = local_args['module']
-            module_path = _module_to_path(self.model, module)
-            if module_path and module_path[0] == '.':
-                module_path = module_path[1:]
-            local_args['path'] = module_path
-            self.module_groups.append(local_args)
-
-    def __getstate__(self):
-        return {
-            'defaults': self.defaults,
-            'module_groups': self.module_groups,
-        }
-
-    def __setstate__(self, state):
-        self.__dict__.update(state)
-
-    def __repr__(self):
-        format_string = self.__class__.__name__ + ' ('
-        for i, sparse_args in enumerate(self.module_groups):
-            module = sparse_args['module']
-            format_string += '\n'
-            format_string += f'\tModule Group {i}\n'
-            format_string += f'\t    module: {module}\n'
-            for key in sorted(sparse_args.keys()):
-                if key == 'module':
-                    continue
-                format_string += f'\t    {key}: {sparse_args[key]}\n'
-        format_string += ')'
-        return format_string
+    def __init__(self, defaults):
+        super().__init__(defaults)
 
     def bias_hook(self, module, input, output):
         if getattr(module, '_bias', None) is not None:
@@ -122,12 +47,15 @@
             output += bias
         return output
 
-    def prepare(self, use_path=False, *args, **kwargs):
+    def _prepare(self, use_path=False, *args, **kwargs):
         r"""Adds mask parametrization to the layer weight
         """
+        self.activation_handles = []  # store removable hook handles
+        self.bias_handles = []
+
         for config in self.module_groups:
             if use_path:
-                module = _path_to_module(self.model, config['path'])
+                module = fqn_to_module(self.model, config['fqn'])
             else:
                 module = config['module']
 
@@ -152,10 +80,10 @@
                 module.bias = None
             self.bias_handles.append(module.register_forward_hook(self.bias_hook))
 
-    def convert(self, use_path=False, *args, **kwargs):
+    def squash_mask(self, use_path=False, *args, **kwargs):
         for config in self.module_groups:
             if use_path:
-                module = _path_to_module(self.model, config['path'])
+                module = fqn_to_module(self.model, config['fqn'])
             else:
                 module = config['module']
             parametrize.remove_parametrizations(module, 'weight',
@@ -166,17 +94,6 @@
             del module._buffers['mask']
             delattr(module, 'mask')
 
-    def step(self, use_path=True):
-        if not self.enable_mask_update:
-            return
-        with torch.no_grad():
-            for config in self.module_groups:
-                if use_path:
-                    module = _path_to_module(self.model, config['path'])
-                else:
-                    module = config['module']
-                self.update_mask(module, **config)
-
     @abc.abstractmethod
     def update_mask(self, layer, **kwargs):
         pass
diff --git a/torch/ao/sparsity/sparsifier/base_sparsifier.py b/torch/ao/sparsity/sparsifier/base_sparsifier.py
index d6bc7d7..1d01b71 100644
--- a/torch/ao/sparsity/sparsifier/base_sparsifier.py
+++ b/torch/ao/sparsity/sparsifier/base_sparsifier.py
@@ -8,30 +8,12 @@ import torch
 from torch import nn
 from torch.nn.utils import parametrize
 
-from .utils import FakeSparsity
+from .utils import FakeSparsity, module_to_fqn, fqn_to_module
 
 SUPPORTED_MODULES = {
     nn.Linear
 }
 
-def _module_to_fqn(model, layer, prefix=''):
-    for name, child in model.named_children():
-        new_name = prefix + '.' + name
-        if child is layer:
-            return new_name
-        child_path = _module_to_fqn(child, layer, prefix=new_name)
-        if child_path is not None:
-            return child_path
-    return None
-
-def _fqn_to_module(model, path):
-    path = path.split('.')
-    for name in path:
-        model = getattr(model, name, None)
-        if model is None:
-            return None
-    return model
-
 
 class BaseSparsifier(abc.ABC):
     r"""Base class for all sparsifiers.
@@ -136,7 +118,7 @@
         module_groups = copy.deepcopy(state_dict['module_groups'])
         states = state_dict['state']
         for fqn, s in states.items():
-            layer = _fqn_to_module(self.model, fqn)
+            layer = fqn_to_module(self.model, fqn)
             if strict and layer is None:
                 raise RuntimeError(f'Error loading {fqn} into the model')
 
@@ -186,7 +168,7 @@
             local_args = copy.deepcopy(self.defaults)
             local_args.update(module_config)
             module = local_args['module']
-            module_fqn = _module_to_fqn(model, module)
+            module_fqn = module_to_fqn(model, module)
             if module_fqn and module_fqn[0] == '.':
                 module_fqn = module_fqn[1:]
             local_args['fqn'] = module_fqn
diff --git a/torch/ao/sparsity/sparsifier/utils.py b/torch/ao/sparsity/sparsifier/utils.py
index 6271a8d..3124b1b 100644
--- a/torch/ao/sparsity/sparsifier/utils.py
+++ b/torch/ao/sparsity/sparsifier/utils.py
@@ -1,5 +1,23 @@
 from torch import nn
 
+def module_to_fqn(model, layer, prefix=''):
+    for name, child in model.named_children():
+        new_name = prefix + '.' + name
+        if child is layer:
+            return new_name
+        child_path = module_to_fqn(child, layer, prefix=new_name)
+        if child_path is not None:
+            return child_path
+    return None
+
+def fqn_to_module(model, path):
+    path = path.split('.')
+    for name in path:
+        model = getattr(model, name, None)
+        if model is None:
+            return None
+    return model
+
 # Parametrizations
 class FakeSparsity(nn.Module):
     r"""Parametrization for the weights. Should be attached to the 'weight' or
-- 
2.7.4
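
Example: a minimal sketch of the updated flow, for reference; it is not part
of the patch. The import paths are assumed from the file layout above, and
`TinyModel`/`NoopPruner` are hypothetical stand-ins (the tests use
`SimplePruner`/`MultiplePruner`, whose mask policies are not shown here):

    import torch
    from torch import nn

    # Assumed import paths, inferred from this patch's file layout.
    from torch.ao.sparsity import fqn_to_module, module_to_fqn
    from torch.ao.sparsity.experimental.pruner.base_pruner import BasePruner

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(16, 16)

        def forward(self, x):
            return self.linear(x)

    class NoopPruner(BasePruner):
        # update_mask is the only abstract method a subclass must implement.
        # A real pruner would mark entries of the layer's pruning
        # parametrization here; this placeholder leaves the mask unchanged.
        def update_mask(self, layer, **kwargs):
            pass

    model = TinyModel()

    # New API: the constructor takes only the defaults dict; the model and
    # config are passed to prepare(), not to the constructor.
    pruner = NoopPruner(None)
    pruner.prepare(model, None)   # config=None -> all supported layers

    pruner.step()                 # mask update is now enabled by default
    assert model(torch.ones(128, 16)).shape == (128, 16)

    pruner.squash_mask()          # formerly convert(): removes parametrizations

    # The promoted helpers map between modules and fully qualified names;
    # module_to_fqn returns a leading '.', which callers strip.
    fqn = module_to_fqn(model, model.linear).lstrip('.')
    assert fqn_to_module(model, fqn) is model.linear

With this change the pruner shares the BaseSparsifier lifecycle
(prepare -> step -> squash_mask), so code written against the sparsifier
API can drive a pruner the same way.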