diff --git a/brambox/boxes/statistics/pr.py b/brambox/boxes/statistics/pr.py index 1e687906c7ae0c82e604558d1b68788588a1a682..c2302f8e9d74a2803e49a1f8234816a7cd53aa9f 100644 --- a/brambox/boxes/statistics/pr.py +++ b/brambox/boxes/statistics/pr.py @@ -5,13 +5,9 @@ # # Functions for generating PR-curve values and calculating average precision # - -import math -from statistics import mean import numpy as np import scipy.interpolate - -from .util import * +from .util import match_detections __all__ = ['pr', 'ap'] @@ -54,7 +50,7 @@ def ap(precision, recall, num_of_samples=100): p = np.array(precision) r = np.array(recall) p_start = p[np.argmin(r)] - samples = np.arange(0., 1., 1.0/num_of_samples) + samples = np.arange(0., 1., 1.0 / num_of_samples) interpolated = scipy.interpolate.interp1d(r, p, fill_value=(p_start, 0.), bounds_error=False)(samples) avg = sum(interpolated) / len(interpolated) elif len(precision) > 0 and len(recall) > 0: diff --git a/brambox/boxes/statistics/util.py b/brambox/boxes/statistics/util.py index ddf2f631c4af0e06608e4e3e48ab3d16121f2da5..7dea7599ed92a94fe6fe3ebe6e40041e3e39d427 100644 --- a/brambox/boxes/statistics/util.py +++ b/brambox/boxes/statistics/util.py @@ -109,7 +109,7 @@ def match_detections(detection_results, ground_truth, overlap_threshold, overlap positives.append((detection.confidence, False)) # sort matches by confidence from high to low - positives = sorted(positives, key=lambda d: d[0], reverse=True) + positives = sorted(positives, key=lambda d: (d[0], -d[1]), reverse=True) tps = [] fps = []