[AutoTVM] Fix hang/crash issues on feature extraction (#3689)
authorLianmin Zheng <lianminzheng@gmail.com>
Fri, 2 Aug 2019 16:14:27 +0000 (00:14 +0800)
committerWuwei Lin <wuwei@apache.org>
Fri, 2 Aug 2019 16:14:27 +0000 (09:14 -0700)
* [AutoTVM] Fix hang/crash issues on feature extraction

* Update xgboost_cost_model.py

* fix lint

python/tvm/autotvm/tuner/xgboost_cost_model.py
src/autotvm/touch_extractor.cc

index e278957..2653651 100644 (file)
@@ -312,9 +312,16 @@ class XGBoostCostModel(CostModel):
             for i, fea in zip(need_extract, feas):
                 fea_cache[i] = fea
 
-        ret = np.empty((len(indexes), fea_cache[indexes[0]].shape[-1]), dtype=np.float32)
+        feature_len = None
+        for idx in indexes:
+            if fea_cache[idx] is not None:
+                feature_len = fea_cache[idx].shape[-1]
+                break
+
+        ret = np.empty((len(indexes), feature_len), dtype=np.float32)
         for i, ii in enumerate(indexes):
-            ret[i, :] = fea_cache[ii]
+            t = fea_cache[ii]
+            ret[i, :] = t if t is not None else 0
         return ret
 
     def __del__(self):
@@ -327,71 +334,88 @@ _extract_task = None
 
 def _extract_itervar_feature_index(index):
     """extract iteration var feature for an index in extract_space"""
-    config = _extract_space.get(index)
-    with _extract_target:
-        sch, args = _extract_task.instantiate(config)
-    fea = feature.get_itervar_feature_flatten(sch, args, take_log=True)
-    fea = np.concatenate((fea, list(config.get_other_option().values())))
-    return fea
+    try:
+        config = _extract_space.get(index)
+        with _extract_target:
+            sch, args = _extract_task.instantiate(config)
+        fea = feature.get_itervar_feature_flatten(sch, args, take_log=True)
+        fea = np.concatenate((fea, list(config.get_other_option().values())))
+        return fea
+    except Exception:  # pylint: disable=broad-except
+        return None
 
 def _extract_itervar_feature_log(arg):
     """extract iteration var feature for log items"""
-    inp, res = arg
-    config = inp.config
-    with inp.target:
-        sch, args = inp.task.instantiate(config)
-    fea = feature.get_itervar_feature_flatten(sch, args, take_log=True)
-    x = np.concatenate((fea, list(config.get_other_option().values())))
-
-    if res.error_no == 0:
-        y = inp.task.flop / np.mean(res.costs)
-    else:
-        y = 0.0
-    return x, y
+    try:
+        inp, res = arg
+        config = inp.config
+        with inp.target:
+            sch, args = inp.task.instantiate(config)
+        fea = feature.get_itervar_feature_flatten(sch, args, take_log=True)
+        x = np.concatenate((fea, list(config.get_other_option().values())))
+
+        if res.error_no == 0:
+            y = inp.task.flop / np.mean(res.costs)
+        else:
+            y = 0.0
+        return x, y
+    except Exception:  # pylint: disable=broad-except
+        return None
 
 def _extract_knob_feature_index(index):
     """extract knob feature for an index in extract_space"""
-    config = _extract_space.get(index)
-    return config.get_flatten_feature()
+    try:
+        config = _extract_space.get(index)
+        return config.get_flatten_feature()
+    except Exception:  # pylint: disable=broad-except
+        return None
 
 def _extract_knob_feature_log(arg):
     """extract knob feature for log items"""
-    inp, res = arg
-    config = inp.config
-    x = config.get_flatten_feature()
-
-    if res.error_no == 0:
-        with inp.target:  # necessary, for calculating flops of this task
-            inp.task.instantiate(config)
-        y = inp.task.flop / np.mean(res.costs)
-    else:
-        y = 0.0
-    return x, y
+    try:
+        inp, res = arg
+        config = inp.config
+        x = config.get_flatten_feature()
+
+        if res.error_no == 0:
+            with inp.target:  # necessary, for calculating flops of this task
+                inp.task.instantiate(config)
+            y = inp.task.flop / np.mean(res.costs)
+        else:
+            y = 0.0
+        return x, y
+    except Exception:  # pylint: disable=broad-except
+        return None
 
 def _extract_curve_feature_index(index):
     """extract sampled curve feature for an index in extract_space"""
-    config = _extract_space.get(index)
-    with _extract_target:
-        sch, args = _extract_task.instantiate(config)
-    fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
-    fea = np.concatenate((fea, list(config.get_other_option().values())))
-    return np.array(fea)
+    try:
+        config = _extract_space.get(index)
+        with _extract_target:
+            sch, args = _extract_task.instantiate(config)
+        fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
+        fea = np.concatenate((fea, list(config.get_other_option().values())))
+        return np.array(fea)
+    except Exception:  # pylint: disable=broad-except
+        return None
 
 def _extract_curve_feature_log(arg):
     """extract sampled curve feature for log items"""
-    inp, res = arg
-    config = inp.config
-    with inp.target:
-        sch, args = inp.task.instantiate(config)
-    fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
-    x = np.concatenate((fea, list(config.get_other_option().values())))
-
-    if res.error_no == 0:
-        y = inp.task.flop / np.mean(res.costs)
-    else:
-        y = 0.0
-    return x, y
-
+    try:
+        inp, res = arg
+        config = inp.config
+        with inp.target:
+            sch, args = inp.task.instantiate(config)
+        fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
+        x = np.concatenate((fea, list(config.get_other_option().values())))
+
+        if res.error_no == 0:
+            y = inp.task.flop / np.mean(res.costs)
+        else:
+            y = 0.0
+        return x, y
+    except Exception:  # pylint: disable=broad-except
+        return None
 
 def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
                     maximize=False, verbose_eval=True):
index 002b970..d40ebd6 100644 (file)
@@ -131,7 +131,9 @@ void TouchExtractor::ExitItervar_() {
   }
   itervar_stack_.pop_back();
 
-  topdown_product_ /= itervar_map[var].length;
+  int64_t length = itervar_map[var].length;
+  if (length != 0)
+      topdown_product_ /= length;
   int64_t bottomup_product = -1;
   for (auto kv : itervar_map[var].touch_feature) {
     bottomup_product = std::max(bottomup_product, kv.second.count * kv.second.reuse);