[pruner] add support for pruning BatchNorm2d (#63519)
authorKaren Zhou <kazhou@fb.com>
Wed, 25 Aug 2021 16:55:02 +0000 (09:55 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 25 Aug 2021 16:56:19 +0000 (09:56 -0700)
commit83b132b112c2e035a23dcab4a88393209c4325ee
tree29c0c4c4f1b26052058ba0586bf1c6cae901434f
parentc1dfd58715c73dba3c089b2993e62d03a8647407
[pruner] add support for pruning BatchNorm2d (#63519)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63519

If the pruner is configured to prune biases along with weights, and the model contains BatchNorm2d layers following pruned Conv2d layers, then the corresponding channels of those BatchNorm layers must also be pruned.

Specifically, they need to be zeroed out rather than fully removed, since in eager mode the dimensions between layers must be preserved.

To do this, we add a pruning parametrization called `ZeroesParametrization` which zeroes out pruned channels, rather than removing them.
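
A minimal sketch of the idea (not the exact code from `parametrization.py`; the real class may differ in details): a parametrization module keeps a set of pruned channel indices and zeroes those entries of the parameter on every forward pass, so the tensor keeps its original shape.

```python
import torch.nn as nn
from torch.nn.utils import parametrize

# Sketch of a zeroing parametrization: pruned channels are masked to zero
# on every access to the parametrized tensor, so its shape never changes.
class ZeroesParametrization(nn.Module):
    def __init__(self):
        super().__init__()
        self.pruned_outputs = set()  # channel indices to zero out

    def forward(self, x):
        x = x.clone()
        x[list(self.pruned_outputs)] = 0.0
        return x

bn = nn.BatchNorm2d(8)
parametrize.register_parametrization(bn, "weight", ZeroesParametrization())
bn.parametrizations["weight"][0].pruned_outputs.update({0, 3})
print(bn.weight)  # entries 0 and 3 are zero; the tensor still has 8 entries
```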

In the config, the user must provide a tuple of the Conv2d and BatchNorm layers that go together. The `prepare` method adds the tuple to `module_groups`; it then attaches a `PruningParametrization` to the Conv2d layer and a `ZeroesParametrization` to the BatchNorm layer, and sets their pruned sets to the same set. That way, during `step`, both masks are updated with the same pruned indices. A usage sketch follows below.
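
Roughly, usage looks like the following. `MyPruner` is a hypothetical `BasePruner` subclass (a concrete pruner must supply the mask-update logic), and the exact `prepare` signature is paraphrased from this description rather than copied from the code:

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),
    nn.BatchNorm2d(4),
    nn.ReLU(),
)

# Conv2d/BatchNorm2d pairs whose channels must be pruned together.
config = [(model[0], model[1])]

pruner = MyPruner()            # hypothetical BasePruner subclass
pruner.prepare(model, config)  # attaches PruningParametrization to the conv and
                               # ZeroesParametrization to the batch norm, sharing
                               # one pruned set between them
pruner.step()                  # both masks are updated with the same indices
```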

ghstack-source-id: 136562278

Test Plan:
`buck test mode/dev-nosan //caffe2/test:ao -- TestBasePruner`

https://pxl.cl/1N1P6

Reviewed By: z-a-f

Differential Revision: D30349855

fbshipit-source-id: 3199d3688d5a70963f9b32d7a8fdac3962ae6a65
test/ao/sparsity/test_pruner.py
torch/ao/sparsity/__init__.py
torch/ao/sparsity/experimental/pruner/base_pruner.py
torch/ao/sparsity/experimental/pruner/parametrization.py