From 77a7285764334cc50eb573c3d2614ffefba4a29f Mon Sep 17 00:00:00 2001 From: Jongsoo Park Date: Fri, 22 Mar 2019 00:49:11 -0700 Subject: [PATCH] add more Python interface functions to make quantization simpler (#18246) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18246 Simplifies the histogram collection and quantization process. Histogram collection before this diff was something like this ``` from caffe2.quantization.server import dnnlowp_pybind11 ... dnnlowp_pybind11.ObserveHistogramOfOutput(hist_file) for ... workspace.RunNet(predict_net) dnnlowp_pybind11.ClearNetObservers() # This is to trigger the Stop function in the observer to dump out the histogram file, but this can have the unintended consequence of also clearing all the other useful observers we attached ``` After this diff we can ``` workspace.CreateNet(predict_net) # Note we need to create net to have a net to attach observer histogram_observer = dnnlowp_pybind11.AddHistogramObserver(predict_net, hist_file) for ... 
workspace.RunNet(predict_net) predict_net.RemoveObserver(histogram_observer) ``` Choosing quantization parameters of weights before this diff was something like this ``` dnnlowp_pybind11.ObserveHistogramOfOutput(weight_hist_file) workspace.RunNetOnce(init_net) dnnlowp_pybind11.ClearNetObservers() # Has same issue as the histogram collection example above dnnlowp_pybind11.RegisterQuantizationParamsWithHistogram( weight_hist_file, is_weight=True, qparams_output_file_name=qparams_file ) workspace.CreateNet(init_net, overwrite=True) dnnlowp_pybind11.ClearNetObservers() logger.info("Loading quantization params from {}".format(qparams_file)) blobs_to_qparams = {} with open(qparams_file) as f: lines = f.readlines() for line in lines: op_id, op_type, output_id, tensor_name, mini, maxi, scale, zero_point, precision = ( line.split() ) op_id = int(op_id) output_id = int(output_id) op = net.Proto().op[op_id] if op_type != op.type or op.output[output_id] != tensor_name: print( "Corrupt qparams file {} {} {} {} {}".format( qparams_file, op_type, op.type, op.output[output_id], tensor_name ) ) blobs_to_qparams[tensor_name] = QuantizationParam(float(scale), int(zero_point)) ``` After this diff this can be simplified to ``` blobs_to_qparams = {} for op in init_net.Proto().op: for output in op.output: scale, zero_point = dnnlowp_pybind11.ChooseQuantizationParams(output) blobs_to_qparams[output] = QuantizationParam(scale, zero_point) ``` Reviewed By: dskhudia Differential Revision: D14544694 fbshipit-source-id: 4fd06cd63256201e2e9d15c39f503138d1be53c2 --- caffe2/quantization/server/pybind.cc | 89 ++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/caffe2/quantization/server/pybind.cc b/caffe2/quantization/server/pybind.cc index 05cb64c..02e9a86 100644 --- a/caffe2/quantization/server/pybind.cc +++ b/caffe2/quantization/server/pybind.cc @@ -2,6 +2,13 @@ #include "activation_distribution_observer.h" #include "caffe2_dnnlowp_utils.h" +namespace caffe2 { 
+namespace python { +// defined in caffe2/python/pybind_state.cc +Workspace* GetCurrentWorkspace(); +} // namespace python +} // namespace caffe2 + PYBIND11_MODULE(dnnlowp_pybind11, m) { using namespace std; using namespace caffe2; @@ -34,6 +41,61 @@ PYBIND11_MODULE(dnnlowp_pybind11, m) { pybind11::arg("mul_nets") = false); m.def( + "AddHistogramObserver", + [](const string& net_name, + const string& out_file_name, + int dump_freq, + bool mul_nets) { + Workspace* gWorkspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(gWorkspace); + CAFFE_ENFORCE( + gWorkspace->GetNet(net_name), "Can't find net ", net_name); + pybind11::gil_scoped_release g; + + NetBase* net = gWorkspace->GetNet(net_name); + const Observable::Observer* observer = nullptr; + + observer = net->AttachObserver(make_unique( + net, out_file_name, 2048, dump_freq, mul_nets)); + + CAFFE_ENFORCE(observer != nullptr); + return pybind11::cast(observer); + }, + pybind11::arg("net_name"), + pybind11::arg("out_file_name"), + pybind11::arg("dump_freq") = -1, + pybind11::arg("mul_nets") = false); + + m.def( + "ChooseQuantizationParams", + [](const std::string& blob_name) { + Workspace* gWorkspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(gWorkspace); + pybind11::gil_scoped_release g; + + const auto* blob = gWorkspace->GetBlob(blob_name); + if (blob == nullptr) { + LOG(WARNING) << "Can't find blob " << blob_name; + } else if (BlobIsTensorType(*blob, CPU)) { + LOG(WARNING) << "Blob " << blob_name << " is not a tensor"; + } else { + const auto& tensor = blob->template Get(); + if (tensor.IsType()) { + dnnlowp::QuantizationFactory* qfactory = + dnnlowp::QuantizationFactory::GetDefaultInstance(); + dnnlowp::TensorQuantizationParams qparams = + qfactory->ChooseQuantizationParams( + tensor.data(), tensor.size(), true /*weight*/); + return std::tuple(qparams.scale, qparams.zero_point); + } else { + LOG(WARNING) << "Blob " << blob_name << " is not a float tensor"; + } + } + return std::tuple(1.0, 
0); + }, + pybind11::arg("blob_name")); + + m.def( "RegisterQuantizationParams", [](const string& min_max_file_name, bool is_weight, @@ -67,6 +129,33 @@ PYBIND11_MODULE(dnnlowp_pybind11, m) { pybind11::arg("qparams_output_file_name") = ""); m.def( + "AddRegisterQuantizationParamsWithHistogramObserver", + [](const string& net_name, + const string& histogram_file_name, + int is_weight, + const string& qparams_output_file_name) { + Workspace* gWorkspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(gWorkspace); + CAFFE_ENFORCE( + gWorkspace->GetNet(net_name), "Can't find net ", net_name); + pybind11::gil_scoped_release g; + + NetBase* net = gWorkspace->GetNet(net_name); + const Observable::Observer* observer = nullptr; + + observer = net->AttachObserver( + make_unique( + net, histogram_file_name, is_weight, qparams_output_file_name)); + + CAFFE_ENFORCE(observer != nullptr); + return pybind11::cast(observer); + }, + pybind11::arg("net_name"), + pybind11::arg("histogram_file_name"), + pybind11::arg("is_weight") = false, + pybind11::arg("qparams_output_file_name") = ""); + + m.def( "AddScaleZeroOffsetArgumentsWithHistogram", [](const pybind11::bytes& net_def_bytes, const string& histogram_file_name) { -- 2.7.4