From 2440c9ced823b449dcb5a718bc485d34f80191d3 Mon Sep 17 00:00:00 2001
From: masahi
Date: Fri, 3 Jan 2020 21:23:09 +0900
Subject: [PATCH] [Quantization] Make calibration faster and more memory usage friendly (#4589)

* Use memory efficient calibrate
* Fixed indexing
* add cpp kl stub
* ported KL cpp from mxnet
* Fixed std::distance arguments order
* remove python implementation
* fix lint and indent
* fix indent
* refactoring
* fix lint
* fix for i386
---
 python/tvm/relay/quantize/_calibrate.py       |  85 ++++++++++--------
 python/tvm/relay/quantize/kl_divergence.py    | 100 +++------------------
 python/tvm/relay/quantize/quantize.py         |   3 +-
 src/relay/pass/quantize/calibrate.cc          | 122 ++++++++++++++++++++++++++
 src/relay/pass/quantize/quantize.h            |   2 +
 tests/python/relay/test_pass_auto_quantize.py |  11 +++
 6 files changed, 199 insertions(+), 124 deletions(-)

diff --git a/python/tvm/relay/quantize/_calibrate.py b/python/tvm/relay/quantize/_calibrate.py
index 21254fa..d904fed 100644
--- a/python/tvm/relay/quantize/_calibrate.py
+++ b/python/tvm/relay/quantize/_calibrate.py
@@ -33,24 +33,7 @@
 from ...contrib import graph_runtime
 from .kl_divergence import _find_scale_by_kl
 
 
-def collect_stats(mod, dataset):
-    """Given an annotated graph, create a profile graph to collect profile data from the
-    calibration dataset. This pass collects simulated_quantize op input into a tuple.
-    Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile
-    graph.
-
-    Parameters
-    ----------
-    mod: Module
-        The simulation graph after annotation.
-
-    Returns
-    -------
-    ret: list of ndarray
-        List of output data of each layer
-    """
-
-    logging.info("collecting statistics for calibration...")
+def _get_profile_runtime(mod):
     func = mod['main']
     func = _quantize.CreateStatsCollector(func)
 
@@ -63,30 +46,61 @@
     with _transform.build_config(opt_level=3):
         graph, lib, params = _build_module.build(func, target=target)
 
-    outputs = []
     runtime = graph_runtime.create(graph, lib, ctx)
     runtime.set_input(**params)
+    return runtime
+
+
+def collect_stats(mod, dataset, chunk_by=-1):
+    """Given an annotated graph, create a profile graph to collect profile data from the
+    calibration dataset. This pass collects simulated_quantize op input into a tuple.
+    Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile
+    graph.
+
+    Parameters
+    ----------
+    mod: Module
+        The simulation graph after annotation.
+
+    dataset: Iterable[NDArray]
+        The calibration dataset.
+
+    chunk_by: optional, int
+        The size of the chunk to be returned in one iteration. It is meant to be
+        used for reducing memory usage. If not specified, return samples for
+        all layers in one chunk.
+
+    Returns
+    -------
+    ret: Iterable[list of ndarray]
+        List of output data of each layer, chunked by the chunk_by parameter.
+    """
+    logging.info("collecting statistics for calibration...")
+    runtime = _get_profile_runtime(mod)
     num_outputs = runtime.get_num_outputs()
-    outputs = [[] for i in range(num_outputs)]
+    chunk_by = num_outputs if chunk_by == -1 else chunk_by
 
-    for batch in dataset:
-        runtime.set_input(**batch)
-        runtime.run()
-        for i in range(num_outputs):
-            output = runtime.get_output(i).asnumpy()
-            outputs[i].append(output)
-    for i in range(num_outputs):
-        outputs[i] = np.concatenate(outputs[i]).reshape(-1)
-    return outputs
+    for i in range(0, num_outputs, chunk_by):
+        outputs = [[] for _ in range(min(chunk_by, num_outputs - i))]
+        for batch in dataset:
+            runtime.set_input(**batch)
+            runtime.run()
+            for j in range(i, min(i + chunk_by, num_outputs)):
+                outputs[j - i].append(runtime.get_output(j).asnumpy())
+        yield [np.concatenate(output).reshape(-1) for output in outputs]
 
 
-def _kl_scale(stats):
-    with mp.Pool() as pool:
+def _kl_scale(mod, dataset):
+    cfg = quantize.current_qconfig()
+    chunk_by = cfg.calibrate_chunk_by
+    scales = []
+    for samples in collect_stats(mod, dataset, chunk_by):
         logging.info("finding threshold with kl for calibration...")
-        scales = list(pool.map(_find_scale_by_kl, stats))
+        with mp.Pool() as pool:
+            scales += list(pool.map(_find_scale_by_kl, samples))
 
-    def func(sq_call):  # pylint: disable=unused-argument
+    def func(_):
         scale = scales[func.scale_idx]
         func.scale_idx += 1
         return scale
@@ -168,13 +182,12 @@ def calibrate(dataset=None):
     ret: Function
         The module pass function.
     """
-    def wrapped_func(mod, ctx):  # pylint: disable=unused-argument
+    def wrapped_func(mod, _):
         """make transform.module pass happy"""
         cfg = quantize.current_qconfig()
 
         if cfg.calibrate_mode == 'kl_divergence':
-            stats = collect_stats(mod, dataset)
-            input_scale_func = _kl_scale(stats)
+            input_scale_func = _kl_scale(mod, dataset)
         elif cfg.calibrate_mode == 'global_scale':
             input_scale_func = _global_scale
         else:
diff --git a/python/tvm/relay/quantize/kl_divergence.py b/python/tvm/relay/quantize/kl_divergence.py
index 2feb514..6492750 100644
--- a/python/tvm/relay/quantize/kl_divergence.py
+++ b/python/tvm/relay/quantize/kl_divergence.py
@@ -16,36 +16,14 @@
 # under the License.
 """Find optimal scale for quantization by minimizing KL-divergence"""
 
-try:
-    from scipy import stats
-except ImportError:
-    stats = None
-
+import ctypes
 import numpy as np
 
-
-def _smooth_distribution(p, eps=0.0001):
-    """Given a discrete distribution (may have not been normalized to 1),
-    smooth it by replacing zeros with eps multiplied by a scaling factor and taking the
-    corresponding amount off the non-zero values.
-    Ref: http://hanj.cs.illinois.edu/cs412/bk3/KL-divergence.pdf
-    """
-    is_zeros = (p == 0).astype(np.float32)
-    is_nonzeros = (p != 0).astype(np.float32)
-    n_zeros = is_zeros.sum()
-    n_nonzeros = p.size - n_zeros
-    if not n_nonzeros:
-        raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
-    eps1 = eps * float(n_zeros) / float(n_nonzeros)
-    assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1)
-    hist = p.astype(np.float32)
-    hist += eps * is_zeros + (-eps1) * is_nonzeros
-    assert (hist <= 0).sum() == 0
-    return hist
+from . import _quantize
 
 
-# pylint: disable=invalid-name
-def _find_scale_by_kl(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255):
+def _find_scale_by_kl(arr, quantized_dtype='int8',
+                      num_bins=8001, num_quantized_bins=255):
     """Given a tensor, find the optimal threshold for quantizing it.
     The reference distribution is `q`, and the candidate distribution is `p`.
     `q` is a truncated version of the original distribution.
@@ -54,73 +32,21 @@ def _find_scale_by_kl(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255):
     http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
     """
     assert isinstance(arr, np.ndarray)
-    assert stats is not None, "scipy needs to be installed for \
-        utilizing kl calibration during quantization"
 
     min_val = np.min(arr)
     max_val = np.max(arr)
-    th = max(abs(min_val), abs(max_val))
+    thres = max(abs(min_val), abs(max_val))
 
     if min_val >= 0 and quantized_dtype in ['uint8']:
         # We need to move negative bins to positive bins to fit uint8 range.
         num_quantized_bins = num_quantized_bins * 2 + 1
 
-    hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th))
-    zero_bin_idx = num_bins // 2
-    num_half_quantized_bins = num_quantized_bins // 2
-
-    thresholds = np.zeros(num_bins // 2 + 1 - num_quantized_bins // 2)
-    divergence = np.zeros_like(thresholds)
-    quantized_bins = np.zeros(num_quantized_bins, dtype=np.int32)
-    # i means the number of bins on half axis excluding the zero bin.
-    for i in range(num_quantized_bins // 2,
-                   num_bins // 2 + 1):
-        p_bin_idx_start = zero_bin_idx - i
-        p_bin_idx_stop = zero_bin_idx + i + 1
-        thresholds[i - num_half_quantized_bins] = hist_edges[p_bin_idx_stop]
-        sliced_nd_hist = hist[p_bin_idx_start:p_bin_idx_stop]
-
-        # generate reference distribution p
-        p = sliced_nd_hist.copy()
-        assert p.size % 2 == 1
-        assert p.size >= num_quantized_bins
-        # put left outlier count in p[0]
-        left_outlier_count = np.sum(hist[0:p_bin_idx_start])
-        p[0] += left_outlier_count
-        # put right outlier count in p[-1]
-        right_outlier_count = np.sum(hist[p_bin_idx_stop:])
-        p[-1] += right_outlier_count
-        # is_nonzeros[k] indicates whether hist[k] is nonzero
-        is_nonzeros = (p != 0).astype(np.int32)
+    def get_pointer(arr, ctypes_type):
+        ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes_type))
+        return ctypes.cast(ptr, ctypes.c_void_p)
 
-        # calculate how many bins should be merged to generate quantized distribution q
-        num_merged_bins = sliced_nd_hist.size // num_quantized_bins
-        # merge hist into num_quantized_bins bins
-        for j in range(num_quantized_bins):
-            start = j * num_merged_bins
-            stop = start + num_merged_bins
-            quantized_bins[j] = sliced_nd_hist[start:stop].sum()
-        quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum()
-        # expand quantized_bins into p.size bins
-        q = np.zeros(sliced_nd_hist.size, dtype=np.float32)
-        for j in range(num_quantized_bins):
-            start = j * num_merged_bins
-            if j == num_quantized_bins - 1:
-                stop = len(is_nonzeros)
-            else:
-                stop = start + num_merged_bins
-            norm = is_nonzeros[start:stop].sum()
-            if norm != 0:
-                q[start:stop] = float(quantized_bins[j]) / float(norm)
-        q[p == 0] = 0
-        p = _smooth_distribution(p)
-        # There is a chance that q is an invalid probability distribution.
-        try:
-            q = _smooth_distribution(q)
-        except ValueError:
-            divergence[i - num_half_quantized_bins] = float("inf")
-        divergence[i - num_half_quantized_bins] = stats.entropy(p, q)
+    hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-thres, thres))
+    # Keep the converted arrays referenced so the raw pointers below stay valid
+    # until the FFI call returns; the C++ side reads int32 and float32 buffers.
+    hist = hist.astype(np.int32)
+    hist_edges = hist_edges.astype(np.float32)
+    hist_ptr = get_pointer(hist, ctypes.c_int)
+    hist_edges_ptr = get_pointer(hist_edges, ctypes.c_float)
 
-    min_divergence_idx = np.argmin(divergence)
-    opt_th = thresholds[min_divergence_idx]
-    return opt_th
+    return _quantize.FindScaleByKLMinimization(hist_ptr, hist_edges_ptr,
+                                               num_bins, num_quantized_bins)
diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py
index 1d60145..ac5387c 100644
--- a/python/tvm/relay/quantize/quantize.py
+++ b/python/tvm/relay/quantize/quantize.py
@@ -81,7 +81,8 @@ class QConfig(NodeBase):
         "do_simulation": False,
         "round_for_shift": True,
         "debug_enabled_ops": None,
-        "rounding": "UPWARD"
+        "rounding": "UPWARD",
+        "calibrate_chunk_by": -1,
     }
 
     # pylint: disable=no-member
diff --git a/src/relay/pass/quantize/calibrate.cc b/src/relay/pass/quantize/calibrate.cc
index f9893f5..bcf82c0 100644
--- a/src/relay/pass/quantize/calibrate.cc
+++ b/src/relay/pass/quantize/calibrate.cc
@@ -26,12 +26,122 @@
 #include 
 #include 
 #include 
+#include <numeric>
 
 #include "./quantize.h"
 
 namespace tvm {
 namespace relay {
 namespace quantize {
 
+// KL divergence minimization code is adapted from MXNet.
+// The original one is in incubator-mxnet/src/operator/quantization/calibrate.cc
+static std::vector<float> SmoothDistribution(const std::vector<float>& p,
+                                             const float eps = 0.0001) {
+  std::vector<size_t> is_zeros(p.size());
+  std::vector<size_t> is_nonzeros(p.size());
+  {
+    auto it = p.begin();
+    std::generate(is_zeros.begin(), is_zeros.end(),
+                  [&it]() { return static_cast<size_t>(*(it++) == 0.f); });
+  }
+  {
+    auto it = p.begin();
+    std::generate(is_nonzeros.begin(), is_nonzeros.end(),
+                  [&it]() { return static_cast<size_t>(*(it++) != 0.f); });
+  }
+  size_t n_zeros = std::accumulate(is_zeros.begin(), is_zeros.end(), 0);
+  size_t n_nonzeros = p.size() - n_zeros;
+  if (!n_nonzeros) {
+    // The discrete probability distribution is malformed. All entries are 0.
+    return std::vector<float>();
+  }
+  float eps1 = eps * static_cast<float>(n_zeros) / static_cast<float>(n_nonzeros);
+  if (eps1 >= 1.0) return std::vector<float>();
+  auto ret = p;
+  for (size_t i = 0; i < p.size(); i++) {
+    ret[i] += eps * is_zeros[i] - eps1 * is_nonzeros[i];
+  }
+  return ret;
+}
+
+static float ComputeEntropy(float* p, float* q, size_t size) {
+  float p_sum = std::accumulate(p, p + size, 0.f);
+  float q_sum = std::accumulate(q, q + size, 0.f);
+  float ret = 0;
+  for (size_t i = 0; i < size; i++) {
+    CHECK(p[i] > 0 && q[i] > 0);
+    p[i] /= p_sum;
+    q[i] /= q_sum;
+    if (p[i] && q[i]) ret += p[i] * std::log(p[i] / q[i]);
+  }
+  return ret;
+}
+
+float MinimizeKL(const std::vector<int>& hist,
+                 const std::vector<float>& hist_edges,
+                 int num_bins, int num_quantized_bins) {
+  const int zero_bin_idx = num_bins / 2;
+  const int num_half_quantized_bins = num_quantized_bins / 2;
+  std::vector<float> thresholds(num_bins / 2 + 1 - num_quantized_bins / 2, 0.f);
+  std::vector<float> divergence(thresholds.size(), 0.f);
+  std::vector<int32_t> quantized_bins(num_quantized_bins, 0);
+  for (int i = num_quantized_bins / 2; i < zero_bin_idx + 1; ++i) {
+    const int p_bin_idx_start = zero_bin_idx - i;
+    const int p_bin_idx_stop = zero_bin_idx + i + 1;
+    thresholds[i - num_half_quantized_bins] = hist_edges[p_bin_idx_stop];
+
+    std::vector<int> sliced_nd_hist(p_bin_idx_stop - p_bin_idx_start);
+    std::vector<float> p(sliced_nd_hist.size());
+    p[0] = 0;
+    p.back() = 0;
+    for (int j = 0; j < num_bins; j++) {
+      if (j <= p_bin_idx_start) {
+        p[0] += hist[j];
+      } else if (j >= p_bin_idx_stop) {
+        p.back() += hist[j];
+      } else {
+        sliced_nd_hist[j - p_bin_idx_start] = hist[j];
+        p[j - p_bin_idx_start] = hist[j];
+      }
+    }
+    // calculate how many bins should be merged to generate quantized distribution q
+    const auto num_merged_bins = sliced_nd_hist.size() / num_quantized_bins;
+    for (int j = 0; j < num_quantized_bins; j++) {
+      const int start = j * num_merged_bins;
+      const int stop = (j + 1) * num_merged_bins;
+      quantized_bins[j] =
+          std::accumulate(sliced_nd_hist.begin() + start, sliced_nd_hist.begin() + stop, 0);
+    }
+    quantized_bins.back() += std::accumulate(
+        sliced_nd_hist.begin() + static_cast<int32_t>(num_quantized_bins * num_merged_bins),
+        sliced_nd_hist.end(), 0);
+    // expand quantized_bins into p.size bins
+    std::vector<float> q(sliced_nd_hist.size(), 0);
+    for (int j = 0; j < num_quantized_bins; j++) {
+      const int start = j * num_merged_bins;
+      const int stop = (j == num_quantized_bins - 1) ? q.size() : ((j + 1) * num_merged_bins);
+      int norm = std::count_if(sliced_nd_hist.begin() + start, sliced_nd_hist.begin() + stop,
+                               [](size_t i) { return i != 0; });
+      if (norm) {
+        for (int k = start; k < stop; k++) {
+          // use floating point division here to match the reference Python implementation
+          if (p[k]) q[k] = static_cast<float>(quantized_bins[j]) / norm;
+        }
+      }
+    }
+    p = SmoothDistribution(p);
+    q = SmoothDistribution(q);
+
+    if (!q.size()) {
+      divergence[i - num_half_quantized_bins] = std::numeric_limits<float>::infinity();
+    } else {
+      divergence[i - num_half_quantized_bins] = ComputeEntropy(p.data(), q.data(), p.size());
+    }
+  }
+  auto min_divergence_idx = std::distance(divergence.begin(),
+                                          std::min_element(divergence.begin(), divergence.end()));
+  return thresholds[min_divergence_idx];
+}
+
 class StatsCollector : private ExprMutator {
  public:
   StatsCollector() : simulated_quantize_op_(Op::Get("relay.op.annotation.simulated_quantize")) {}
@@ -95,6 +205,18 @@ Expr CreateStatsCollector(const Expr& expr) {
 TVM_REGISTER_API("relay._quantize.CreateStatsCollector")
 .set_body_typed(CreateStatsCollector);
 
+
+TVM_REGISTER_API("relay._quantize.FindScaleByKLMinimization")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  int* hist_ptr = static_cast<int*>(static_cast<void*>(args[0]));
+  float* hist_edges_ptr = static_cast<float*>(static_cast<void*>(args[1]));
+  int num_bins = args[2];
+  int num_quantized_bins = args[3];
+  std::vector<int> hist(hist_ptr, hist_ptr + num_bins);
+  std::vector<float> hist_edges(hist_edges_ptr, hist_edges_ptr + num_bins + 1);
+  ret[0] = MinimizeKL(hist, hist_edges, num_bins, num_quantized_bins);
+});
+
 }  // namespace quantize
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/pass/quantize/quantize.h b/src/relay/pass/quantize/quantize.h
index bfb7653..4f6cd2d 100644
--- a/src/relay/pass/quantize/quantize.h
+++ b/src/relay/pass/quantize/quantize.h
@@ -78,6 +78,7 @@ class QConfigNode : public Object {
   bool round_for_shift = true;
   Array<Expr> debug_enabled_ops = Array<Expr>(ObjectPtr<Object>(nullptr));
   std::string rounding = "UPWARD";
+  int calibrate_chunk_by = -1;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("nbit_input", &nbit_input);
@@ -94,6 +95,7 @@
     v->Visit("round_for_shift", &round_for_shift);
     v->Visit("debug_enabled_ops", &debug_enabled_ops);
     v->Visit("rounding", &rounding);
+    v->Visit("calibrate_chunk_by", &calibrate_chunk_by);
   }
 
   static constexpr const char* _type_key = "relay.quantize.QConfig";
diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py
index 5b2e368..443d2e4 100644
--- a/tests/python/relay/test_pass_auto_quantize.py
+++ b/tests/python/relay/test_pass_auto_quantize.py
@@ -67,7 +67,18 @@ def test_calibrate_target(create_target=False):
         relay.quantize.quantize(mod, params, dataset)
 
 
+def test_calibrate_memory_bound():
+    mod, params = testing.resnet.get_workload(num_layers=18)
+    dataset = get_calibration_dataset("data")
+    import multiprocessing
+    num_cpu = multiprocessing.cpu_count()
+    with relay.quantize.qconfig(calibrate_mode="kl_divergence",
+                                calibrate_chunk_by=num_cpu):
+        relay.quantize.quantize(mod, params, dataset)
+
+
 if __name__ == "__main__":
     test_mul_rewrite()
     test_calibrate_target(False)
    test_calibrate_target(True)
+    test_calibrate_memory_bound()
-- 
2.7.4
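
A minimal usage sketch of the chunked calibration path introduced above, assuming a TVM build that
includes this patch. The workload, input shape, dataset size, and chunk size below are illustrative
choices only; the calibrate_chunk_by value bounds how many layer outputs are materialized at once
(the new test uses one chunk per CPU so the multiprocessing pool stays busy).

    import multiprocessing

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay import testing

    # Any Relay workload works; resnet-18 mirrors the new test case.
    mod, params = testing.resnet.get_workload(num_layers=18)

    # Calibration dataset: an iterable of {input_name: NDArray} batches.
    dataset = [{"data": tvm.nd.array(
                    np.random.uniform(size=(1, 3, 224, 224)).astype("float32"))}
               for _ in range(8)]

    # calibrate_chunk_by limits how many intermediate outputs are held in memory
    # while KL thresholds are computed chunk by chunk.
    with relay.quantize.qconfig(calibrate_mode="kl_divergence",
                                calibrate_chunk_by=multiprocessing.cpu_count()):
        quantized_mod = relay.quantize.quantize(mod, params, dataset)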
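For intuition about what the new C++ MinimizeKL routine computes, _find_scale_by_kl can be called
directly on a NumPy array. This is an illustrative sketch with a made-up activation distribution,
again assuming a build that contains this patch:

    import numpy as np
    from tvm.relay.quantize.kl_divergence import _find_scale_by_kl

    # Mostly-Gaussian activations plus a handful of large outliers.
    samples = np.concatenate([np.random.normal(0.0, 1.0, size=100000),
                              np.random.uniform(8.0, 10.0, size=100)]).astype(np.float32)

    threshold = _find_scale_by_kl(samples)
    # The KL-optimal clipping threshold should sit well below the absolute maximum,
    # because the outlier bins carry negligible probability mass.
    print(threshold, np.abs(samples).max())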