// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_Hh_
#define DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_Hh_
#include "structural_object_detection_trainer_abstract.h"
#include "../algs.h"
#include "../optimization.h"
#include "structural_svm_object_detection_problem.h"
#include "../image_processing/object_detector.h"
#include "../image_processing/box_overlap_testing.h"
#include "../image_processing/full_object_detection.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
// Default, do-nothing hook for configuring a nuclear norm regularizer.  This
// generic version deliberately leaves both of its arguments untouched; scanner
// types that need such configuration provide their own overload of this function.
template <typename image_scanner_type, typename svm_struct_prob_type>
void configure_nuclear_norm_regularizer (
    const image_scanner_type& /*scanner*/,
    svm_struct_prob_type& /*svm_problem*/
)
{
}
// ----------------------------------------------------------------------------------------
// Trainer that produces an object_detector<image_scanner_type>.  Training builds a
// structural_svm_object_detection_problem from the supplied images and truth
// detections, copies this object's settings onto that problem, and runs the oca
// solver on it to obtain the weight vector given to the returned detector.
template <typename image_scanner_type>
class structural_object_detection_trainer : noncopyable
{
public:
    using scalar_type = double;
    using mem_manager_type = default_memory_manager;
    using trained_function_type = object_detector<image_scanner_type>;

    // Creates a trainer whose internal scanner takes its configuration from
    // scanner_ (through copy_configuration()).  Initial settings: C = 1,
    // epsilon = 0.1, match_eps = 0.5, 2 threads, max cache size of 5, both loss
    // values equal to 1, quiet output, and the overlap tester flagged as automatic.
    explicit structural_object_detection_trainer (const image_scanner_type& scanner_) :
        C(1),
        eps(0.1),
        match_eps(0.5),
        verbose(false),
        num_threads(2),
        max_cache_size(5),
        loss_per_missed_target(1),
        loss_per_false_alarm(1),
        auto_overlap_tester(true)
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(scanner_.get_num_detection_templates() > 0,
            "\t structural_object_detection_trainer::structural_object_detection_trainer(scanner_)"
            << "\n\t You can't have zero detection templates"
            << "\n\t this: " << this
            );

        scanner.copy_configuration(scanner_);
    }

    // Read-only access to the scanner held by this trainer.
    const image_scanner_type& get_scanner () const { return scanner; }

    // True from construction until set_overlap_tester() is called.
    bool auto_set_overlap_tester () const { return auto_overlap_tester; }

    // Installs a user supplied overlap tester and switches off the automatic flag.
    void set_overlap_tester (const test_box_overlap& tester)
    {
        overlap_tester = tester;
        auto_overlap_tester = false;
    }

    // Returns the user supplied overlap tester.  Asserts that
    // auto_set_overlap_tester() is false, i.e. that a tester was explicitly set.
    test_box_overlap get_overlap_tester () const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(auto_set_overlap_tester() == false,
            "\t test_box_overlap structural_object_detection_trainer::get_overlap_tester()"
            << "\n\t You can't call this function if the overlap tester is generated dynamically."
            << "\n\t this: " << this
            );

        return overlap_tester;
    }

    // Thread count handed to the SVM problem when training.
    void set_num_threads (unsigned long num) { num_threads = num; }
    unsigned long get_num_threads () const { return num_threads; }

    // Sets the epsilon value forwarded to the SVM problem.  Asserts eps_ > 0.
    void set_epsilon (scalar_type eps_)
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(eps_ > 0,
            "\t void structural_object_detection_trainer::set_epsilon()"
            << "\n\t eps_ must be greater than 0"
            << "\n\t eps_: " << eps_
            << "\n\t this: " << this
            );

        eps = eps_;
    }

    scalar_type get_epsilon () const
    {
        return eps;
    }

    // The runtime limit is not stored here; both calls go straight to the solver.
    void set_max_runtime (const std::chrono::nanoseconds& max_runtime)
    {
        solver.set_max_runtime(max_runtime);
    }

    std::chrono::nanoseconds get_max_runtime () const { return solver.get_max_runtime(); }

    // Cache size forwarded to the SVM problem when training.
    void set_max_cache_size (unsigned long max_size) { max_cache_size = max_size; }
    unsigned long get_max_cache_size () const { return max_cache_size; }

    // Toggle whether train() asks the SVM problem to be verbose.
    void be_verbose () { verbose = true; }
    void be_quiet () { verbose = false; }

    // Replace / fetch (by value) the oca solver used to optimize the problem.
    void set_oca (const oca& item) { solver = item; }
    const oca get_oca () const { return solver; }

    // Sets the C value forwarded to the SVM problem.  Asserts C_ > 0.
    void set_c (scalar_type C_)
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(C_ > 0,
            "\t void structural_object_detection_trainer::set_c()"
            << "\n\t C_ must be greater than 0"
            << "\n\t C_: " << C_
            << "\n\t this: " << this
            );

        C = C_;
    }

    scalar_type get_c () const { return C; }

    // Sets the match_eps value forwarded to the SVM problem.  Asserts that the
    // argument lies strictly between 0 and 1.  Note that the parameter is named
    // eps and therefore hides the eps member inside this function.
    void set_match_eps (double eps)
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(0 < eps && eps < 1,
            "\t void structural_object_detection_trainer::set_match_eps(eps)"
            << "\n\t Invalid inputs were given to this function "
            << "\n\t eps: " << eps
            << "\n\t this: " << this
            );

        match_eps = eps;
    }

    double get_match_eps () const { return match_eps; }

    double get_loss_per_missed_target () const { return loss_per_missed_target; }

    // Sets the loss value forwarded to the SVM problem for missed targets.
    // Asserts loss > 0.
    void set_loss_per_missed_target (double loss)
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(loss > 0,
            "\t void structural_object_detection_trainer::set_loss_per_missed_target(loss)"
            << "\n\t Invalid inputs were given to this function "
            << "\n\t loss: " << loss
            << "\n\t this: " << this
            );

        loss_per_missed_target = loss;
    }

    double get_loss_per_false_alarm () const { return loss_per_false_alarm; }

    // Sets the loss value forwarded to the SVM problem for false alarms.
    // Asserts loss > 0.
    void set_loss_per_false_alarm (double loss)
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(loss > 0,
            "\t void structural_object_detection_trainer::set_loss_per_false_alarm(loss)"
            << "\n\t Invalid inputs were given to this function "
            << "\n\t loss: " << loss
            << "\n\t this: " << this
            );

        loss_per_false_alarm = loss;
    }

    // Trains from full_object_detection truth data with nothing to ignore: one
    // empty ignore list is created per image and a default constructed
    // test_box_overlap is used as the ignore overlap tester.
    template <typename image_array_type>
    const trained_function_type train (
        const image_array_type& images,
        const std::vector<std::vector<full_object_detection> >& truth_object_detections
    ) const
    {
        const std::vector<std::vector<rectangle> > no_ignore(images.size());
        return train_impl(images, truth_object_detections, no_ignore, test_box_overlap());
    }

    // Trains from full_object_detection truth data with caller supplied ignore
    // rectangles; all arguments are passed unchanged to train_impl().
    template <typename image_array_type>
    const trained_function_type train (
        const image_array_type& images,
        const std::vector<std::vector<full_object_detection> >& truth_object_detections,
        const std::vector<std::vector<rectangle> >& ignore,
        const test_box_overlap& ignore_overlap_tester = test_box_overlap()
    ) const
    {
        return train_impl(images, truth_object_detections, ignore, ignore_overlap_tester);
    }

    // Trains from plain rectangle truth data with nothing to ignore.  Delegates to
    // the four argument rectangle overload using one empty ignore list per image.
    template <typename image_array_type>
    const trained_function_type train (
        const image_array_type& images,
        const std::vector<std::vector<rectangle> >& truth_object_detections
    ) const
    {
        const std::vector<std::vector<rectangle> > no_ignore(images.size());
        return train(images, truth_object_detections, no_ignore, test_box_overlap());
    }

    // Trains from plain rectangle truth data.  Each rectangle is wrapped in a
    // full_object_detection (keeping the per-image grouping and order) and the
    // result is given to train_impl().
    template <typename image_array_type>
    const trained_function_type train (
        const image_array_type& images,
        const std::vector<std::vector<rectangle> >& truth_object_detections,
        const std::vector<std::vector<rectangle> >& ignore,
        const test_box_overlap& ignore_overlap_tester = test_box_overlap()
    ) const
    {
        const unsigned long num_images = truth_object_detections.size();
        std::vector<std::vector<full_object_detection> > truth_dets(num_images);
        for (unsigned long img = 0; img < num_images; ++img)
        {
            for (const rectangle& box : truth_object_detections[img])
                truth_dets[img].push_back(full_object_detection(box));
        }

        return train_impl(images, truth_dets, ignore, ignore_overlap_tester);
    }

private:

    // Shared implementation reached by every public train() overload.  Validates
    // the inputs (only when ENABLE_ASSERTS is defined), sets up the structural SVM
    // problem with this trainer's settings, runs the solver, and wraps the learned
    // weights in a detector.
    template <typename image_array_type>
    const trained_function_type train_impl (
        const image_array_type& images,
        const std::vector<std::vector<full_object_detection> >& truth_object_detections,
        const std::vector<std::vector<rectangle> >& ignore,
        const test_box_overlap& ignore_overlap_tester
    ) const
    {
#ifdef ENABLE_ASSERTS
        // make sure requires clause is not broken
        DLIB_ASSERT(is_learning_problem(images,truth_object_detections) == true && images.size() == ignore.size(),
            "\t trained_function_type structural_object_detection_trainer::train()"
            << "\n\t invalid inputs were given to this function"
            << "\n\t images.size(): " << images.size()
            << "\n\t ignore.size(): " << ignore.size()
            << "\n\t truth_object_detections.size(): " << truth_object_detections.size()
            << "\n\t is_learning_problem(images,truth_object_detections): " << is_learning_problem(images,truth_object_detections)
            );

        // Every truth detection must have the part count the scanner reports and
        // must keep all of its parts inside its rectangle.
        for (unsigned long i = 0; i < truth_object_detections.size(); ++i)
        {
            for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j)
            {
                DLIB_ASSERT(truth_object_detections[i][j].num_parts() == get_scanner().get_num_movable_components_per_detection_template() &&
                    all_parts_in_rect(truth_object_detections[i][j]) == true,
                    "\t trained_function_type structural_object_detection_trainer::train()"
                    << "\n\t invalid inputs were given to this function"
                    << "\n\t truth_object_detections["<<i<<"]["<<j<<"].num_parts(): " <<
                    truth_object_detections[i][j].num_parts()
                    << "\n\t get_scanner().get_num_movable_components_per_detection_template(): " <<
                    get_scanner().get_num_movable_components_per_detection_template()
                    << "\n\t all_parts_in_rect(truth_object_detections["<<i<<"]["<<j<<"]): " << all_parts_in_rect(truth_object_detections[i][j])
                    );
            }
        }
#endif

        typedef structural_svm_object_detection_problem<image_scanner_type,image_array_type > problem_type;
        problem_type problem(scanner, overlap_tester, auto_overlap_tester, images,
                             truth_object_detections, ignore, ignore_overlap_tester, num_threads);

        // Copy the trainer settings onto the problem.
        if (verbose)
        {
            problem.be_verbose();
        }
        problem.set_c(C);
        problem.set_epsilon(eps);
        problem.set_max_cache_size(max_cache_size);
        problem.set_match_eps(match_eps);
        problem.set_loss_per_missed_target(loss_per_missed_target);
        problem.set_loss_per_false_alarm(loss_per_false_alarm);

        // Scanner specific hook; the generic overload does nothing.
        configure_nuclear_norm_regularizer(scanner, problem);

        // Run the optimizer; it writes the learned weight vector into weights.
        matrix<double,0,1> weights;
        solver(problem,weights);

        // The detector is built with the overlap tester reported by the problem
        // rather than the overlap_tester member.
        return trained_function_type(scanner, problem.get_overlap_tester(), weights);
    }

    image_scanner_type scanner;          // configured from the constructor argument
    test_box_overlap overlap_tester;     // user supplied tester, see set_overlap_tester()
    double C;                            // forwarded to the problem's set_c()
    oca solver;                          // optimizer run by train_impl()
    double eps;                          // forwarded to the problem's set_epsilon()
    double match_eps;                    // forwarded to the problem's set_match_eps()
    bool verbose;                        // when true the problem is told to be verbose
    unsigned long num_threads;           // passed to the problem's constructor
    unsigned long max_cache_size;        // forwarded to the problem's set_max_cache_size()
    double loss_per_missed_target;       // forwarded to set_loss_per_missed_target()
    double loss_per_false_alarm;         // forwarded to set_loss_per_false_alarm()
    bool auto_overlap_tester;            // false once set_overlap_tester() has been called
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_Hh_