typedef std::map<int, std::vector<util::NormalizedBBox> > LabelBBox;
+ inline int getNumOfTargetClasses() {
+ unsigned numBackground =
+ (_backgroundLabelId >= 0 && _backgroundLabelId < _numClasses) ? 1 : 0;
+ return (_numClasses - numBackground);
+ }
+
bool getParameterDict(const LayerParams ¶ms,
const std::string ¶meterName,
DictValue& result)
LabelBBox::const_iterator label_bboxes = decodeBBoxes.find(label);
if (label_bboxes == decodeBBoxes.end())
CV_Error_(cv::Error::StsError, ("Could not find location predictions for label %d", label));
+ int limit = (getNumOfTargetClasses() == 1) ? _keepTopK : std::numeric_limits<int>::max();
if (_bboxesNormalized)
NMSFast_(label_bboxes->second, scores, _confidenceThreshold, _nmsThreshold, 1.0, _topK,
- indices[c], util::caffe_norm_box_overlap);
+ indices[c], util::caffe_norm_box_overlap, limit);
else
NMSFast_(label_bboxes->second, scores, _confidenceThreshold, _nmsThreshold, 1.0, _topK,
- indices[c], util::caffe_box_overlap);
+ indices[c], util::caffe_box_overlap, limit);
numDetections += indices[c].size();
}
if (_keepTopK > -1 && numDetections > (size_t)_keepTopK)
// score_threshold: a threshold used to filter detection results.
// nms_threshold: a threshold used in non maximum suppression.
// top_k: if not > 0, keep at most top_k picked indices.
+// limit: early terminate once the # of picked indices has reached it.
// indices: the kept indices of bboxes after nms.
template <typename BoxType>
inline void NMSFast_(const std::vector<BoxType>& bboxes,
const std::vector<float>& scores, const float score_threshold,
const float nms_threshold, const float eta, const int top_k,
- std::vector<int>& indices, float (*computeOverlap)(const BoxType&, const BoxType&))
+ std::vector<int>& indices,
+ float (*computeOverlap)(const BoxType&, const BoxType&),
+ int limit = std::numeric_limits<int>::max())
{
CV_Assert(bboxes.size() == scores.size());
float overlap = computeOverlap(bboxes[idx], bboxes[kept_idx]);
keep = overlap <= adaptive_threshold;
}
- if (keep)
+ if (keep) {
indices.push_back(idx);
+ if (indices.size() >= limit) {
+ break;
+ }
+ }
if (keep && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}