From ccc6a41ad3923bc7ae308749063f29a64b120ba2 Mon Sep 17 00:00:00 2001 From: joncrall Date: Sun, 3 Jun 2018 17:59:21 -0400 Subject: [PATCH] Removed non-deterministic behavior from match_detections --- brambox/boxes/statistics/pr.py | 8 ++------ brambox/boxes/statistics/util.py | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/brambox/boxes/statistics/pr.py b/brambox/boxes/statistics/pr.py index 1e68790..c2302f8 100644 --- a/brambox/boxes/statistics/pr.py +++ b/brambox/boxes/statistics/pr.py @@ -5,13 +5,9 @@ # # Functions for generating PR-curve values and calculating average precision # - -import math -from statistics import mean import numpy as np import scipy.interpolate - -from .util import * +from .util import match_detections __all__ = ['pr', 'ap'] @@ -54,7 +50,7 @@ def ap(precision, recall, num_of_samples=100): p = np.array(precision) r = np.array(recall) p_start = p[np.argmin(r)] - samples = np.arange(0., 1., 1.0/num_of_samples) + samples = np.arange(0., 1., 1.0 / num_of_samples) interpolated = scipy.interpolate.interp1d(r, p, fill_value=(p_start, 0.), bounds_error=False)(samples) avg = sum(interpolated) / len(interpolated) elif len(precision) > 0 and len(recall) > 0: diff --git a/brambox/boxes/statistics/util.py b/brambox/boxes/statistics/util.py index ddf2f63..7dea759 100644 --- a/brambox/boxes/statistics/util.py +++ b/brambox/boxes/statistics/util.py @@ -109,7 +109,7 @@ def match_detections(detection_results, ground_truth, overlap_threshold, overlap positives.append((detection.confidence, False)) # sort matches by confidence from high to low - positives = sorted(positives, key=lambda d: d[0], reverse=True) + positives = sorted(positives, key=lambda d: (d[0], -d[1]), reverse=True) tps = [] fps = [] -- GitLab