handle a rare case of histogram min is inf/nan (#18239)
authorJongsoo Park <jongsoo@fb.com>
Mon, 1 Apr 2019 04:25:17 +0000 (21:25 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 1 Apr 2019 04:32:54 +0000 (21:32 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18239

When min is inf or nan, we get UBSAN errors

Reviewed By: csummersea

Differential Revision: D14537668

fbshipit-source-id: e70ffb5ecd2b10793356070c69fdabf8f25b203e

caffe2/quantization/server/activation_distribution_observer.cc
caffe2/quantization/server/activation_distribution_observer.h

index 900f730..37f7180 100644 (file)
@@ -342,6 +342,10 @@ void HistogramNetObserver::DumpAndReset_(
                      << " has an empty range: min " << hist->Min()
                      << " and max " << hist->Max();
       }
+      if (hist->GetHistogram()->empty()) {
+        LOG(WARNING) << "Histogram of "
+                     << info->min_max_info.tensor_infos[i].name << " is empty";
+      }
 
       ostringstream ost;
       ost << op_index << " " << info->min_max_info.type << " " << i << " "
@@ -352,15 +356,11 @@ void HistogramNetObserver::DumpAndReset_(
         ost << " " << c;
       }
 
-      f << ost.str() << endl;
       if (print_total_min_max) {
         LOG(INFO) << this << " " << ost.str();
       }
 
-      if (hist->GetHistogram()->empty()) {
-        LOG(WARNING) << "Histogram of "
-                     << info->min_max_info.tensor_infos[i].name << " is empty";
-      }
+      f << ost.str() << endl;
 
       if (!print_total_min_max) {
         info->histograms[i] = DynamicHistogram(hist->GetHistogram()->size());
@@ -575,7 +575,8 @@ RegisterQuantizationParamsWithHistogramNetObserver::
         qparams = qfactory->ChooseQuantizationParams(hist, is_weight);
       } else {
         qparams.scale = 0.1f;
-        qparams.zero_point = -min / qparams.scale;
+        qparams.zero_point =
+            (isinf(min) || isnan(min)) ? 0 : (-min / qparams.scale);
         qparams.precision = 8;
       }
 
index 6162e62..72ecab3 100644 (file)
@@ -100,6 +100,13 @@ class HistogramObserver final : public ObserverBase<OperatorBase> {
 
 class HistogramNetObserver final : public NetObserver {
  public:
+  /**
+   * @params mul_nets true if we expect multiple nets with the same name so
+   *                  we include extra information in the file name to
+   *                  distinghuish them
+   * @params dump_freq if not -1 we dump histogram every dump_freq invocation
+   *                   of the net
+   */
   explicit HistogramNetObserver(
       NetBase* subject,
       const std::string& out_file_name,