add more Python interface functions to make quantization simpler (#18246)
author Jongsoo Park <jongsoo@fb.com>
Fri, 22 Mar 2019 07:49:11 +0000 (00:49 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Mar 2019 07:52:24 +0000 (00:52 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18246

Simplifies histogram collection and quantization process.

Histogram collection before this diff was something like this
```
from caffe2.quantization.server import dnnlowp_pybind11
...
dnnlowp_pybind11.ObserveHistogramOfOutput(hist_file)
for ...
   workspace.RunNet(predict_net)
dnnlowp_pybind11.ClearNetObservers()  # This triggers the Stop function in the observer to dump out the histogram file, but it can have the unintended consequence of also clearing all the other useful observers we attached
```

After this diff we can
```
workspace.CreateNet(predict_net)  # Note: the net must be created first so that there is a net to attach the observer to
histogram_observer = dnnlowp_pybind11.AddHistogramObserver(predict_net, hist_file)
for ...
   workspace.RunNet(predict_net)
predict_net.RemoveObserver(histogram_observer)
```

Choosing quantization parameters of weights before this diff was something like this
```
dnnlowp_pybind11.ObserveHistogramOfOutput(weight_hist_file)
workspace.RunNetOnce(init_net)
dnnlowp_pybind11.ClearNetObservers() # Has same issue as the histogram collection example above

dnnlowp_pybind11.RegisterQuantizationParamsWithHistogram(
    weight_hist_file, is_weight=True, qparams_output_file_name=qparams_file
)
workspace.CreateNet(init_net, overwrite=True)
dnnlowp_pybind11.ClearNetObservers()

logger.info("Loading quantization params from {}".format(qparams_file))
blobs_to_qparams = {}
with open(qparams_file) as f:
    lines = f.readlines()
for line in lines:
    op_id, op_type, output_id, tensor_name, mini, maxi, scale, zero_point, precision = (
        line.split()
    )
    op_id = int(op_id)
    output_id = int(output_id)
    op = net.Proto().op[op_id]
    if op_type != op.type or op.output[output_id] != tensor_name:
        print(
            "Corrupt qparams file {} {} {} {} {}".format(
                qparams_file, op_type, op.type, op.output[output_id], tensor_name
            )
        )
    blobs_to_qparams[tensor_name] = QuantizationParam(float(scale), int(zero_point))

```

After this diff this can be simplified to
```
blobs_to_qparams = {}
for op in init_net.Proto().op:
    for output in op.output:
        scale, zero_point = dnnlowp_pybind11.ChooseQuantizationParams(output)
        blobs_to_qparams[output] = QuantizationParam(scale, zero_point)
```

Reviewed By: dskhudia

Differential Revision: D14544694

fbshipit-source-id: 4fd06cd63256201e2e9d15c39f503138d1be53c2

caffe2/quantization/server/pybind.cc

index 05cb64c..02e9a86 100644 (file)
@@ -2,6 +2,13 @@
 #include "activation_distribution_observer.h"
 #include "caffe2_dnnlowp_utils.h"
 
+namespace caffe2 {
+namespace python {
+// Defined in caffe2/python/pybind_state.cc; declared here so the bindings below can call it.
+Workspace* GetCurrentWorkspace();
+} // namespace python
+} // namespace caffe2
+
 PYBIND11_MODULE(dnnlowp_pybind11, m) {
   using namespace std;
   using namespace caffe2;
@@ -34,6 +41,61 @@ PYBIND11_MODULE(dnnlowp_pybind11, m) {
       pybind11::arg("mul_nets") = false);
 
   m.def(
+      "AddHistogramObserver",
+      [](const string& net_name,
+         const string& out_file_name,
+         int dump_freq,
+         bool mul_nets) {
+        Workspace* gWorkspace = caffe2::python::GetCurrentWorkspace();
+        CAFFE_ENFORCE(gWorkspace);
+        CAFFE_ENFORCE(
+            gWorkspace->GetNet(net_name), "Can't find net ", net_name);
+        pybind11::gil_scoped_release g;
+
+        NetBase* net = gWorkspace->GetNet(net_name);
+        const Observable<NetBase>::Observer* observer = nullptr;
+
+        observer = net->AttachObserver(make_unique<HistogramNetObserver>(
+            net, out_file_name, 2048, dump_freq, mul_nets));
+
+        CAFFE_ENFORCE(observer != nullptr);
+        return pybind11::cast(observer);
+      },
+      pybind11::arg("net_name"),
+      pybind11::arg("out_file_name"),
+      pybind11::arg("dump_freq") = -1,
+      pybind11::arg("mul_nets") = false);
+
+  m.def(
+      "ChooseQuantizationParams",
+      [](const std::string& blob_name) {
+        Workspace* gWorkspace = caffe2::python::GetCurrentWorkspace();
+        CAFFE_ENFORCE(gWorkspace);
+        pybind11::gil_scoped_release g;
+
+        const auto* blob = gWorkspace->GetBlob(blob_name);
+        if (blob == nullptr) {
+          LOG(WARNING) << "Can't find blob " << blob_name;
+        } else if (BlobIsTensorType(*blob, CPU)) {
+          LOG(WARNING) << "Blob " << blob_name << " is not a tensor";
+        } else {
+          const auto& tensor = blob->template Get<Tensor>();
+          if (tensor.IsType<float>()) {
+            dnnlowp::QuantizationFactory* qfactory =
+                dnnlowp::QuantizationFactory::GetDefaultInstance();
+            dnnlowp::TensorQuantizationParams qparams =
+                qfactory->ChooseQuantizationParams(
+                    tensor.data<float>(), tensor.size(), true /*weight*/);
+            return std::tuple<float, int>(qparams.scale, qparams.zero_point);
+          } else {
+            LOG(WARNING) << "Blob " << blob_name << " is not a float tensor";
+          }
+        }
+        return std::tuple<float, int>(1.0, 0);
+      },
+      pybind11::arg("blob_name"));
+
+  m.def(
       "RegisterQuantizationParams",
       [](const string& min_max_file_name,
          bool is_weight,
@@ -67,6 +129,33 @@ PYBIND11_MODULE(dnnlowp_pybind11, m) {
       pybind11::arg("qparams_output_file_name") = "");
 
   m.def(
+      "AddRegisterQuantizationParamsWithHistogramObserver",
+      [](const string& net_name,
+         const string& histogram_file_name,
+         int is_weight,
+         const string& qparams_output_file_name) {
+        Workspace* gWorkspace = caffe2::python::GetCurrentWorkspace();
+        CAFFE_ENFORCE(gWorkspace);
+        CAFFE_ENFORCE(
+            gWorkspace->GetNet(net_name), "Can't find net ", net_name);
+        pybind11::gil_scoped_release g;
+
+        NetBase* net = gWorkspace->GetNet(net_name);
+        const Observable<NetBase>::Observer* observer = nullptr;
+
+        observer = net->AttachObserver(
+            make_unique<RegisterQuantizationParamsWithHistogramNetObserver>(
+                net, histogram_file_name, is_weight, qparams_output_file_name));
+
+        CAFFE_ENFORCE(observer != nullptr);
+        return pybind11::cast(observer);
+      },
+      pybind11::arg("net_name"),
+      pybind11::arg("histogram_file_name"),
+      pybind11::arg("is_weight") = false,
+      pybind11::arg("qparams_output_file_name") = "");
+
+  m.def(
       "AddScaleZeroOffsetArgumentsWithHistogram",
       [](const pybind11::bytes& net_def_bytes,
          const string& histogram_file_name) {