// Copyright (C) 2014 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_Hh_
#define DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_Hh_
#include "cross_validate_track_association_trainer_abstract.h"
#include "structural_track_association_trainer.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
namespace impl
{
template <
typename track_association_function,
typename detection_type,
typename label_type
>
void test_track_association_function (
const track_association_function& assoc,
const std::vector<std::vector<labeled_detection<detection_type,label_type> > >& samples,
unsigned long& total_dets,
unsigned long& correctly_associated_dets
)
{
const typename track_association_function::association_function_type& f = assoc.get_assignment_function();
typedef typename detection_type::track_type track_type;
using namespace impl;
dlib::rand rnd;
std::vector<track_type> tracks;
std::map<label_type,long> track_idx; // tracks[track_idx[id]] == track with ID id.
for (unsigned long j = 0; j < samples.size(); ++j)
{
std::vector<labeled_detection<detection_type,label_type> > dets = samples[j];
// Shuffle the order of the detections so we can be sure that there isn't
// anything funny going on like the detections always coming in the same
// order relative to their labels and the association function just gets
// lucky by picking the same assignment ordering every time. So this way
// we know the assignment function really is doing something rather than
// just being lucky.
randomize_samples(dets, rnd);
total_dets += dets.size();
std::vector<long> assignments = f(get_unlabeled_dets(dets), tracks);
std::vector<bool> updated_track(tracks.size(), false);
// now update all the tracks with the detections that associated to them.
for (unsigned long k = 0; k < assignments.size(); ++k)
{
// If the detection is associated to tracks[assignments[k]]
if (assignments[k] != -1)
{
tracks[assignments[k]].update_track(dets[k].det);
updated_track[assignments[k]] = true;
// if this detection was supposed to go to this track
if (track_idx.count(dets[k].label) && track_idx[dets[k].label]==assignments[k])
++correctly_associated_dets;
track_idx[dets[k].label] = assignments[k];
}
else
{
track_type new_track;
new_track.update_track(dets[k].det);
tracks.push_back(new_track);
// if this detection was supposed to go to a new track
if (track_idx.count(dets[k].label) == 0)
++correctly_associated_dets;
track_idx[dets[k].label] = tracks.size()-1;
}
}
// Now propagate all the tracks that didn't get any detections.
for (unsigned long k = 0; k < updated_track.size(); ++k)
{
if (!updated_track[k])
tracks[k].propagate_track();
}
}
}
}
// ----------------------------------------------------------------------------------------
template <
typename track_association_function,
typename detection_type,
typename label_type
>
double test_track_association_function (
const track_association_function& assoc,
const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples
)
{
unsigned long total_dets = 0;
unsigned long correctly_associated_dets = 0;
for (unsigned long i = 0; i < samples.size(); ++i)
{
impl::test_track_association_function(assoc, samples[i], total_dets, correctly_associated_dets);
}
return (double)correctly_associated_dets/(double)total_dets;
}
// ----------------------------------------------------------------------------------------
template <
typename trainer_type,
typename detection_type,
typename label_type
>
double cross_validate_track_association_trainer (
const trainer_type& trainer,
const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples,
const long folds
)
{
const long num_in_test = samples.size()/folds;
const long num_in_train = samples.size() - num_in_test;
std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > > samples_train;
long next_test_idx = 0;
unsigned long total_dets = 0;
unsigned long correctly_associated_dets = 0;
for (long i = 0; i < folds; ++i)
{
samples_train.clear();
// load up the training samples
long next = (next_test_idx + num_in_test)%samples.size();
for (long cnt = 0; cnt < num_in_train; ++cnt)
{
samples_train.push_back(samples[next]);
next = (next + 1)%samples.size();
}
const track_association_function<detection_type>& df = trainer.train(samples_train);
for (long cnt = 0; cnt < num_in_test; ++cnt)
{
impl::test_track_association_function(df, samples[next_test_idx], total_dets, correctly_associated_dets);
next_test_idx = (next_test_idx + 1)%samples.size();
}
}
return (double)correctly_associated_dets/(double)total_dets;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_Hh_