[pruner] amend base pruner API to match base sparsifier (#63178)
authorKaren Zhou <kazhou@fb.com>
Tue, 24 Aug 2021 17:17:28 +0000 (10:17 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 24 Aug 2021 17:25:43 +0000 (10:25 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63178

Update base pruner API to match base sparsifier API as defined in D28970960 / PR58955

Changes include:
- `enable_mask_update = True` in `__init__`
- `prepare` takes model and config instead of constructor
- convert functionality renamed to `squash_mask`, `convert` method call now raises Error
- `activation_handles` ad `bias_handles` initialized in `_prepare` instead of constructor
ghstack-source-id: 136467595

Test Plan:
Function names updates according to changes

`buck test mode/dev-nosan //caffe2/test:ao -- TestBasePruner`

https://pxl.cl/1MTgH

TODO will need to modify `fbcode/scripts/kazhou/fusion_tests.py` to use new API

Reviewed By: z-a-f

Differential Revision: D30287179

fbshipit-source-id: d4727bea1873b500f2d4bb784db26d532bf26cce

test/ao/sparsity/test_pruner.py
torch/ao/sparsity/__init__.py
torch/ao/sparsity/experimental/pruner/base_pruner.py
torch/ao/sparsity/sparsifier/base_sparsifier.py
torch/ao/sparsity/sparsifier/utils.py

index 8f5f6dd..c358df6 100644 (file)
@@ -161,7 +161,7 @@ class TestBasePruner(TestCase):
             # Assume that this is the 1st/only parametrization
             assert type(module.parametrizations.weight[0]) == PruningParametrization
 
-    def _check_pruner_converted(self, model, pruner, device):
+    def _check_pruner_mask_squashed(self, model, pruner, device):
         for g in pruner.module_groups:
             module = g['module']
             assert module.weight.device == device
@@ -184,16 +184,18 @@ class TestBasePruner(TestCase):
         self.assertRaisesRegex(TypeError, 'with abstract methods update_mask',
                                BasePruner)
         model = model.to(device)
-        pruner = SimplePruner(model, None, None)
+        pruner = SimplePruner(None)
+        pruner.prepare(model, None)
         for g in pruner.module_groups:
             module = g['module']
             assert module.weight.device == device
         assert len(pruner.module_groups) == 2
         pruner.step()
         # Can instantiate the model with configs
-        pruner = SimplePruner(model, [model.linear], {'test': 3})
+        pruner = SimplePruner({'test': 3})
+        pruner.prepare(model, [model.linear])
         assert len(pruner.module_groups) == 1
-        assert pruner.module_groups[0]['path'] == 'linear'
+        assert pruner.module_groups[0]['fqn'] == 'linear'
         assert 'test' in pruner.module_groups[0]
         assert pruner.module_groups[0]['test'] == 3
 
@@ -205,8 +207,8 @@ class TestBasePruner(TestCase):
     def _test_prepare_linear_on_device(self, model, device):
         model = model.to(device)
         x = torch.ones(128, 16)
-        pruner = SimplePruner(model, None, None)
-        pruner.prepare()
+        pruner = SimplePruner(None)
+        pruner.prepare(model, None)
         self._check_pruner_prepared(model, pruner, device)
         assert model(x).shape == (128, 16)
 
@@ -219,8 +221,8 @@ class TestBasePruner(TestCase):
     def _test_prepare_conv2d_on_device(self, model, device):
         model = model.to(device)
         x = torch.ones((1, 1, 28, 28))
-        pruner = SimplePruner(model, None, None)
-        pruner.prepare()
+        pruner = SimplePruner(None)
+        pruner.prepare(model, None)
         self._check_pruner_prepared(model, pruner, device)
         assert model(x).shape == (1, 64, 24, 24)
 
@@ -230,51 +232,49 @@ class TestBasePruner(TestCase):
             for model in models:
                 self._test_prepare_conv2d_on_device(model, torch.device(device))
 
-    def _test_convert_linear_on_device(self, model, device):
+    def _test_squash_mask_linear_on_device(self, model, device):
         model = model.to(device)
         x = torch.ones(128, 16)
-        pruner = SimplePruner(model, None, None)
-        pruner.prepare()
-        pruner.convert()
-        self._check_pruner_converted(model, pruner, device)
+        pruner = SimplePruner(None)
+        pruner.prepare(model, None)
+        pruner.squash_mask()
+        self._check_pruner_mask_squashed(model, pruner, device)
         assert model(x).shape == (128, 16)
 
-    def test_convert_linear(self):
+    def test_squash_mask_linear(self):
         models = [Linear(), LinearB()]  # without and with bias
         for device in DEVICES:
             for model in models:
-                self._test_convert_linear_on_device(model, torch.device(device))
+                self._test_squash_mask_linear_on_device(model, torch.device(device))
 
-    def _test_convert_conv2d_on_device(self, model, device):
+    def _test_squash_mask_conv2d_on_device(self, model, device):
         model = model.to(device)
         x = torch.ones((1, 1, 28, 28))
-        pruner = SimplePruner(model, None, None)
-        pruner.prepare()
-        pruner.convert()
-        self._check_pruner_converted(model, pruner, device)
+        pruner = SimplePruner(None)
+        pruner.prepare(model, None)
+        pruner.squash_mask()
+        self._check_pruner_mask_squashed(model, pruner, device)
         assert model(x).shape == (1, 64, 24, 24)
 
-    def test_convert_conv2d(self):
+    def test_squash_mask_conv2d(self):
         models = [Conv2dA(), Conv2dB(), Conv2dC()]
         for device in DEVICES:
             for model in models:
-                self._test_convert_conv2d_on_device(model, torch.device(device))
+                self._test_squash_mask_conv2d_on_device(model, torch.device(device))
 
     def _test_step_linear_on_device(self, model, is_basic, device):
         model = model.to(device)
         if is_basic:
             x = torch.ones(16, 16)
-            pruner = SimplePruner(model, None, None)
-            pruner.prepare()
-            pruner.enable_mask_update = True
+            pruner = SimplePruner(None)
+            pruner.prepare(model, None)
             self._check_pruner_valid_before_step(model, pruner, device)
             pruner.step()
             self._check_pruner_valid_after_step(model, pruner, {1}, device)
         else:
             x = torch.ones(7, 7)
-            pruner = MultiplePruner(model, None, None)
-            pruner.prepare()
-            pruner.enable_mask_update = True
+            pruner = MultiplePruner(None)
+            pruner.prepare(model, None)
             self._check_pruner_valid_before_step(model, pruner, device)
             pruner.step()
             self._check_pruner_valid_after_step(model, pruner, {1, 2}, device)
@@ -291,9 +291,8 @@ class TestBasePruner(TestCase):
     def _test_step_conv2d_on_device(self, model, device):
         model = model.to(device)
         x = torch.ones((1, 1, 28, 28))
-        pruner = SimplePruner(model, None, None)
-        pruner.prepare()
-        pruner.enable_mask_update = True
+        pruner = SimplePruner(None)
+        pruner.prepare(model, None)
         self._check_pruner_valid_before_step(model, pruner, device)
         pruner.step()
         self._check_pruner_valid_after_step(model, pruner, {1}, device)
index 9ba05f2..55b8d70 100644 (file)
@@ -12,6 +12,8 @@ from .scheduler.lambda_scheduler import LambdaSL
 
 # Parametrizations
 from .sparsifier.utils import FakeSparsity
+from .sparsifier.utils import module_to_fqn
+from .sparsifier.utils import fqn_to_module
 
 # === Experimental ===
 
index 92e1945..d89b3cc 100644 (file)
@@ -1,6 +1,5 @@
 
 import abc
-import copy
 
 import torch
 from torch import nn
@@ -10,31 +9,15 @@ from torch.nn.modules.container import ModuleDict, ModuleList
 
 from .parametrization import PruningParametrization, ActivationReconstruction
 
+from torch.ao.sparsity import BaseSparsifier, fqn_to_module
+
 SUPPORTED_MODULES = {
     nn.Linear,
     nn.Conv2d
 }
 
-def _module_to_path(model, layer, prefix=''):
-    for name, child in model.named_children():
-        new_name = prefix + '.' + name
-        if child is layer:
-            return new_name
-        child_path = _module_to_path(child, layer, prefix=new_name)
-        if child_path is not None:
-            return child_path
-    return None
-
-def _path_to_module(model, path):
-    path = path.split('.')
-    for name in path:
-        model = getattr(model, name, None)
-        if model is None:
-            return None
-    return model
-
-
-class BasePruner(abc.ABC):
+
+class BasePruner(BaseSparsifier):
     r"""Base class for all pruners.
 
     Abstract methods that need to be implemented:
@@ -53,66 +36,8 @@ class BasePruner(abc.ABC):
             be updated.
 
     """
-    def __init__(self, model, config, defaults):
-        super().__init__()
-        self.config = config
-        self.defaults = defaults
-        if self.defaults is None:
-            self.defaults = dict()
-
-        self.module_groups = []
-        self.enable_mask_update = False
-        self.activation_handles = []
-        self.bias_handles = []
-
-        self.model = model
-        # If no config -- try getting all the supported layers
-        if self.config is None:
-            # Add all models to the config
-            self.config = []
-            stack = [model]
-            while stack:
-                module = stack.pop()
-                for name, child in module.named_children():
-                    if type(child) in SUPPORTED_MODULES:
-                        self.config.append(child)
-                    else:
-                        stack.append(child)
-
-        for module_config in self.config:
-            if isinstance(module_config, nn.Module):
-                module_config = {'module': module_config}
-            local_args = copy.deepcopy(self.defaults)
-            local_args.update(module_config)
-            module = local_args['module']
-            module_path = _module_to_path(self.model, module)
-            if module_path and module_path[0] == '.':
-                module_path = module_path[1:]
-            local_args['path'] = module_path
-            self.module_groups.append(local_args)
-
-    def __getstate__(self):
-        return {
-            'defaults': self.defaults,
-            'module_groups': self.module_groups,
-        }
-
-    def __setstate__(self, state):
-        self.__dict__.update(state)
-
-    def __repr__(self):
-        format_string = self.__class__.__name__ + ' ('
-        for i, sparse_args in enumerate(self.module_groups):
-            module = sparse_args['module']
-            format_string += '\n'
-            format_string += f'\tModule Group {i}\n'
-            format_string += f'\t    module: {module}\n'
-            for key in sorted(sparse_args.keys()):
-                if key == 'module':
-                    continue
-                format_string += f'\t    {key}: {sparse_args[key]}\n'
-        format_string += ')'
-        return format_string
+    def __init__(self, defaults):
+        super().__init__(defaults)
 
     def bias_hook(self, module, input, output):
         if getattr(module, '_bias', None) is not None:
@@ -122,12 +47,15 @@ class BasePruner(abc.ABC):
             output += bias
         return output
 
-    def prepare(self, use_path=False, *args, **kwargs):
+    def _prepare(self, use_path=False, *args, **kwargs):
         r"""Adds mask parametrization to the layer weight
         """
+        self.activation_handles = []  # store removable hook handles
+        self.bias_handles = []
+
         for config in self.module_groups:
             if use_path:
-                module = _path_to_module(self.model, config['path'])
+                module = fqn_to_module(self.model, config['fqn'])
             else:
                 module = config['module']
 
@@ -152,10 +80,10 @@ class BasePruner(abc.ABC):
                 module.bias = None
             self.bias_handles.append(module.register_forward_hook(self.bias_hook))
 
-    def convert(self, use_path=False, *args, **kwargs):
+    def squash_mask(self, use_path=False, *args, **kwargs):
         for config in self.module_groups:
             if use_path:
-                module = _path_to_module(self.model, config['path'])
+                module = fqn_to_module(self.model, config['fqn'])
             else:
                 module = config['module']
             parametrize.remove_parametrizations(module, 'weight',
@@ -166,17 +94,6 @@ class BasePruner(abc.ABC):
                 del module._buffers['mask']
             delattr(module, 'mask')
 
-    def step(self, use_path=True):
-        if not self.enable_mask_update:
-            return
-        with torch.no_grad():
-            for config in self.module_groups:
-                if use_path:
-                    module = _path_to_module(self.model, config['path'])
-                else:
-                    module = config['module']
-                self.update_mask(module, **config)
-
     @abc.abstractmethod
     def update_mask(self, layer, **kwargs):
         pass
index d6bc7d7..1d01b71 100644 (file)
@@ -8,30 +8,12 @@ import torch
 from torch import nn
 from torch.nn.utils import parametrize
 
-from .utils import FakeSparsity
+from .utils import FakeSparsity, module_to_fqn, fqn_to_module
 
 SUPPORTED_MODULES = {
     nn.Linear
 }
 
-def _module_to_fqn(model, layer, prefix=''):
-    for name, child in model.named_children():
-        new_name = prefix + '.' + name
-        if child is layer:
-            return new_name
-        child_path = _module_to_fqn(child, layer, prefix=new_name)
-        if child_path is not None:
-            return child_path
-    return None
-
-def _fqn_to_module(model, path):
-    path = path.split('.')
-    for name in path:
-        model = getattr(model, name, None)
-        if model is None:
-            return None
-    return model
-
 
 class BaseSparsifier(abc.ABC):
     r"""Base class for all sparsifiers.
@@ -136,7 +118,7 @@ class BaseSparsifier(abc.ABC):
         module_groups = copy.deepcopy(state_dict['module_groups'])
         states = state_dict['state']
         for fqn, s in states.items():
-            layer = _fqn_to_module(self.model, fqn)
+            layer = fqn_to_module(self.model, fqn)
             if strict and layer is None:
                 raise RuntimeError(f'Error loading {fqn} into the model')
 
@@ -186,7 +168,7 @@ class BaseSparsifier(abc.ABC):
             local_args = copy.deepcopy(self.defaults)
             local_args.update(module_config)
             module = local_args['module']
-            module_fqn = _module_to_fqn(model, module)
+            module_fqn = module_to_fqn(model, module)
             if module_fqn and module_fqn[0] == '.':
                 module_fqn = module_fqn[1:]
             local_args['fqn'] = module_fqn
index 6271a8d..3124b1b 100644 (file)
@@ -1,5 +1,23 @@
 from torch import nn
 
+def module_to_fqn(model, layer, prefix=''):
+    for name, child in model.named_children():
+        new_name = prefix + '.' + name
+        if child is layer:
+            return new_name
+        child_path = module_to_fqn(child, layer, prefix=new_name)
+        if child_path is not None:
+            return child_path
+    return None
+
+def fqn_to_module(model, path):
+    path = path.split('.')
+    for name in path:
+        model = getattr(model, name, None)
+        if model is None:
+            return None
+    return model
+
 # Parametrizations
 class FakeSparsity(nn.Module):
     r"""Parametrization for the weights. Should be attached to the 'weight' or