// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_FIND_MAX_PaRSE_CKY_Hh_
#define DLIB_FIND_MAX_PaRSE_CKY_Hh_
#include "find_max_parse_cky_abstract.h"
#include <map>
#include <sstream>
#include <string>
#include <vector>
#include "../serialize.h"
#include "../array2d.h"
namespace dlib
{
// -----------------------------------------------------------------------------------------
template <typename T>
struct constituent
{
    /*!
        A candidate binary constituent as built by find_max_parse_cky(): it
        covers the half-open range [begin, end) of the input sequence and is
        split at k, so the left part [begin, k) carries left_tag and the right
        part [k, end) carries right_tag.
    !*/
    unsigned long begin;
    unsigned long end;
    unsigned long k;
    T left_tag;
    T right_tag;
};
template <typename T>
void serialize(
const constituent<T>& item,
std::ostream& out
)
/*!
    ensures
        - Writes item to out as the fields begin, end, k, left_tag, right_tag,
          in that order, each via the serialize() overload for its type.
        - This field order is the stream format that
          deserialize(constituent<T>&, std::istream&) reads back, so the two
          functions must be changed together.
!*/
{
serialize(item.begin, out);
serialize(item.end, out);
serialize(item.k, out);
serialize(item.left_tag, out);
serialize(item.right_tag, out);
}
template <typename T>
void deserialize(
constituent<T>& item,
std::istream& in
)
/*!
    ensures
        - Reads begin, end, k, left_tag, right_tag from in, in that order (the
          same order serialize(const constituent<T>&, std::ostream&) writes),
          directly into item.
        - NOTE(review): fields are read straight into item, so if a read fails
          part-way item may be left partially updated.
!*/
{
deserialize(item.begin, in);
deserialize(item.end, in);
deserialize(item.k, in);
deserialize(item.left_tag, in);
deserialize(item.right_tag, in);
}
// -----------------------------------------------------------------------------------------
// Sentinel stored in parse_tree_element::left / right when that side of a
// node is not another tree node (e.g. it is a single sequence element).  The
// value ends up in serialized parse trees, so it must not be changed.
const unsigned long END_OF_TREE = 0xFFFFFFFFUL;
// -----------------------------------------------------------------------------------------
template <typename T>
struct parse_tree_element
{
/*!
    One node of a parse tree stored in a std::vector<parse_tree_element<T> >.
    Children are referenced by index into that same vector.
!*/
constituent<T> c; // the span and split point this node covers
T tag; // id for the constituent corresponding to this level of the tree
unsigned long left; // index of the left child node, or END_OF_TREE if the left side is not a tree node
unsigned long right; // index of the right child node, or END_OF_TREE if the right side is not a tree node
double score; // as set by find_max_parse_cky(): rule score plus the scores of both children
};
template <typename T>
void serialize (
const parse_tree_element<T>& item,
std::ostream& out
)
/*!
    ensures
        - Writes item to out as c, tag, left, right, score, in that order.
        - The order defines the stream format read back by
          deserialize(parse_tree_element<T>&, std::istream&); keep them in sync.
!*/
{
serialize(item.c, out);
serialize(item.tag, out);
serialize(item.left, out);
serialize(item.right, out);
serialize(item.score, out);
}
template <typename T>
void deserialize (
parse_tree_element<T>& item,
std::istream& in
)
/*!
    ensures
        - Reads c, tag, left, right, score from in, in the same order that
          serialize(const parse_tree_element<T>&, std::ostream&) writes them,
          directly into item.
!*/
{
deserialize(item.c, in);
deserialize(item.tag, in);
deserialize(item.left, in);
deserialize(item.right, in);
deserialize(item.score, in);
}
// -----------------------------------------------------------------------------------------
namespace impl
{
    template <typename T>
    unsigned long fill_parse_tree(
        std::vector<parse_tree_element<T> >& parse_tree,
        const T& tag,
        const array2d<std::map<T, parse_tree_element<T> > >& back,
        long r, long c
    )
    /*!
        requires
            - back[r][c].size() == 0 || back[r][c].count(tag) != 0
        ensures
            - If back[r][c] is empty, appends nothing and returns END_OF_TREE.
            - Otherwise appends the node back[r][c][tag] to parse_tree, then
              (recursively) its left subtree, then its right subtree, i.e. in
              pre-order.  It also patches each appended node's left/right
              fields with its children's indices.
            - returns the index in parse_tree of the node appended for (r, c, tag).
    !*/
    {
        const std::map<T, parse_tree_element<T> >& cell = back[r][c];

        // Empty cells have no constituent to expand (find_max_parse_cky never
        // fills back[r][r], so single-element spans end up here).
        if (cell.empty())
            return END_OF_TREE;

        // node refers into back, not parse_tree, so it stays valid across the
        // push_back and the recursive calls below.
        const parse_tree_element<T>& node = cell.find(tag)->second;
        const unsigned long node_idx = parse_tree.size();
        parse_tree.push_back(node);

        // c.k is the first position of the right part, so the left child spans
        // [r, k-1] and the right child spans [k, c] (inclusive table indices).
        const long split = node.c.k;
        const unsigned long left_idx  = fill_parse_tree(parse_tree, node.c.left_tag,  back, r, split-1);
        const unsigned long right_idx = fill_parse_tree(parse_tree, node.c.right_tag, back, split, c);

        // Go through the index: parse_tree may have reallocated during recursion.
        parse_tree[node_idx].left  = left_idx;
        parse_tree[node_idx].right = right_idx;
        return node_idx;
    }
}
template <typename T, typename production_rule_function>
void find_max_parse_cky (
const std::vector<T>& sequence,
const production_rule_function& production_rules,
std::vector<parse_tree_element<T> >& parse_tree
)
{
parse_tree.clear();
if (sequence.size() == 0)
return;
array2d<std::map<T,double> > table(sequence.size(), sequence.size());
array2d<std::map<T,parse_tree_element<T> > > back(sequence.size(), sequence.size());
typedef typename std::map<T,double>::iterator itr;
typedef typename std::map<T,parse_tree_element<T> >::iterator itr_b;
for (long r = 0; r < table.nr(); ++r)
table[r][r][sequence[r]] = 0;
std::vector<std::pair<T,double> > possible_tags;
for (long r = table.nr()-2; r >= 0; --r)
{
for (long c = r+1; c < table.nc(); ++c)
{
for (long k = r; k < c; ++k)
{
for (itr i = table[k+1][c].begin(); i != table[k+1][c].end(); ++i)
{
for (itr j = table[r][k].begin(); j != table[r][k].end(); ++j)
{
constituent<T> con;
con.begin = r;
con.end = c+1;
con.k = k+1;
con.left_tag = j->first;
con.right_tag = i->first;
possible_tags.clear();
production_rules(sequence, con, possible_tags);
for (unsigned long m = 0; m < possible_tags.size(); ++m)
{
const double score = possible_tags[m].second + i->second + j->second;
itr match = table[r][c].find(possible_tags[m].first);
if (match == table[r][c].end() || score > match->second)
{
table[r][c][possible_tags[m].first] = score;
parse_tree_element<T> item;
item.c = con;
item.score = score;
item.tag = possible_tags[m].first;
item.left = END_OF_TREE;
item.right = END_OF_TREE;
back[r][c][possible_tags[m].first] = item;
}
}
}
}
}
}
}
// now use back pointers to build the parse trees
const long r = 0;
const long c = back.nc()-1;
if (back[r][c].size() != 0)
{
// find the max scoring element in back[r][c]
itr_b max_i = back[r][c].begin();
itr_b i = max_i;
++i;
for (; i != back[r][c].end(); ++i)
{
if (i->second.score > max_i->second.score)
max_i = i;
}
parse_tree.reserve(c);
impl::fill_parse_tree(parse_tree, max_i->second.tag, back, r, c);
}
}
// -----------------------------------------------------------------------------------------
class parse_tree_to_string_error : public error
{
public:
parse_tree_to_string_error(const std::string& str): error(str) {}
};
namespace impl
{
    // Streams item followed by a space.  Selected when enabled is true.
    template <bool enabled, typename T>
    typename enable_if_c<enabled>::type conditional_print(
        const T& item,
        std::ostream& out
    )
    {
        out << item << " ";
    }

    // No-op overload selected when enabled is false, so T does not need to be
    // streamable in that case.
    template <bool enabled, typename T>
    typename disable_if_c<enabled>::type conditional_print(
        const T& ,
        std::ostream&
    )
    {
    }

    template <bool print_tag, bool skip_tag, typename T, typename U >
    void print_parse_tree_helper (
        const std::vector<parse_tree_element<T> >& tree,
        const std::vector<U>& words,
        unsigned long i,
        const T& tag_to_skip,
        std::ostream& out
    )
    /*!
        Recursively writes the subtree rooted at tree[i] to out as
        "[<tag >left right]".
        - The tag is printed only when print_tag is true.
        - When skip_tag is true, nodes whose tag equals tag_to_skip are written
          without brackets or tag.
        - A child index that is not < tree.size() is a leaf.  The left leaf is
          words[c.begin] and the right leaf is words[c.k].
        - throws parse_tree_to_string_error if such a leaf index is not
          < words.size().
    !*/
    {
        const parse_tree_element<T>& node = tree[i];

        // Evaluated once; decides both the opening and closing bracket.
        const bool bracketed = !skip_tag || node.tag != tag_to_skip;

        if (bracketed)
        {
            out << "[";
            // Going through conditional_print<print_tag> (rather than a runtime
            // if) keeps parse_tree_to_string() compiling when the tag type
            // isn't printable.
            conditional_print<print_tag>(node.tag, out);
        }

        // Left side: a subtree followed by a separating space, or one word.
        if (node.left < tree.size())
        {
            print_parse_tree_helper<print_tag,skip_tag>(tree, words, node.left, tag_to_skip, out);
            out << " ";
        }
        else if (node.c.begin < words.size())
        {
            out << words[node.c.begin] << " ";
        }
        else
        {
            std::ostringstream sout;
            sout << "Parse tree refers to element " << node.c.begin
                 << " of sequence which is only of size " << words.size() << ".";
            throw parse_tree_to_string_error(sout.str());
        }

        // Right side: a subtree, or one word.
        if (node.right < tree.size())
        {
            print_parse_tree_helper<print_tag,skip_tag>(tree, words, node.right, tag_to_skip, out);
        }
        else if (node.c.k < words.size())
        {
            out << words[node.c.k];
        }
        else
        {
            std::ostringstream sout;
            sout << "Parse tree refers to element " << node.c.k
                 << " of sequence which is only of size " << words.size() << ".";
            throw parse_tree_to_string_error(sout.str());
        }

        if (bracketed)
            out << "]";
    }
}
// -----------------------------------------------------------------------------------------
template <typename T, typename U>
std::string parse_tree_to_string (
    const std::vector<parse_tree_element<T> >& tree,
    const std::vector<U>& words,
    const unsigned long root_idx = 0
)
/*!
    Renders the subtree rooted at tree[root_idx] as nested brackets without
    tags, e.g. "[[w0 w1] w2]".
    - Returns "" when root_idx is not a valid index into tree.
    - throws parse_tree_to_string_error if a leaf refers to an element outside
      words.
!*/
{
    std::ostringstream sout;
    if (root_idx < tree.size())
    {
        // skip_tag is false, so the tag_to_skip argument never suppresses
        // anything.  The root's tag is passed only because a value is required.
        impl::print_parse_tree_helper<false,false>(tree, words, root_idx, tree[root_idx].tag, sout);
    }
    return sout.str();
}
// -----------------------------------------------------------------------------------------
template <typename T, typename U>
std::string parse_tree_to_string_tagged (
    const std::vector<parse_tree_element<T> >& tree,
    const std::vector<U>& words,
    const unsigned long root_idx = 0
)
/*!
    Like parse_tree_to_string(), but each bracket starts with the node's tag,
    e.g. "[S [NP w0 w1] w2]".
    - Returns "" when root_idx is not a valid index into tree.
    - throws parse_tree_to_string_error if a leaf refers to an element outside
      words.
!*/
{
    std::ostringstream sout;
    if (root_idx < tree.size())
    {
        // With skip_tag == false, tag_to_skip is unused; pass the root's tag.
        impl::print_parse_tree_helper<true,false>(tree, words, root_idx, tree[root_idx].tag, sout);
    }
    return sout.str();
}
// -----------------------------------------------------------------------------------------
template <typename T, typename U>
std::string parse_trees_to_string (
    const std::vector<parse_tree_element<T> >& tree,
    const std::vector<U>& words,
    const T& tag_to_skip
)
/*!
    Renders the whole tree (rooted at index 0) without tags.  Nodes whose tag
    equals tag_to_skip get no brackets, so their children appear as separate
    top-level trees.
    - Returns "" for an empty tree.
    - throws parse_tree_to_string_error if a leaf refers to an element outside
      words.
!*/
{
    if (tree.empty())
        return std::string();

    std::ostringstream sout;
    const unsigned long root = 0;
    impl::print_parse_tree_helper<false,true>(tree, words, root, tag_to_skip, sout);
    return sout.str();
}
// -----------------------------------------------------------------------------------------
template <typename T, typename U>
std::string parse_trees_to_string_tagged (
    const std::vector<parse_tree_element<T> >& tree,
    const std::vector<U>& words,
    const T& tag_to_skip
)
/*!
    Like parse_trees_to_string(), but each printed bracket starts with its
    node's tag.  Nodes tagged tag_to_skip still get no brackets and no tag.
    - Returns "" for an empty tree.
    - throws parse_tree_to_string_error if a leaf refers to an element outside
      words.
!*/
{
    if (tree.empty())
        return std::string();

    std::ostringstream sout;
    const unsigned long root = 0;
    impl::print_parse_tree_helper<true,true>(tree, words, root, tag_to_skip, sout);
    return sout.str();
}
// -----------------------------------------------------------------------------------------
namespace impl
{
    template <typename T>
    void helper_find_trees_without_tag (
        const std::vector<parse_tree_element<T> >& tree,
        const T& tag,
        std::vector<unsigned long>& tree_roots,
        unsigned long idx
    )
    /*!
        Pre-order walk starting at tree[idx]:
        - A node whose tag differs from tag is appended to tree_roots, and its
          subtree is not explored further.
        - A node carrying tag is descended into, left child first.
        - Indices that are not < tree.size() (e.g. END_OF_TREE) are ignored.
    !*/
    {
        // Out-of-range index: no node here.
        if (idx >= tree.size())
            return;

        const parse_tree_element<T>& node = tree[idx];
        if (node.tag != tag)
        {
            tree_roots.push_back(idx);
            return;
        }

        helper_find_trees_without_tag(tree, tag, tree_roots, node.left);
        helper_find_trees_without_tag(tree, tag, tree_roots, node.right);
    }
}
template <typename T>
void find_trees_not_rooted_with_tag (
    const std::vector<parse_tree_element<T> >& tree,
    const T& tag,
    std::vector<unsigned long>& tree_roots
)
/*!
    Replaces the contents of tree_roots with the indices of the topmost nodes
    whose tag is not tag.  The search descends from the root (index 0) only
    through nodes tagged tag, and results are in pre-order (left before right).
    tree_roots ends up empty when tree is empty.
!*/
{
    tree_roots.clear();
    const unsigned long root = 0;
    impl::helper_find_trees_without_tag(tree, tag, tree_roots, root);
}
// -----------------------------------------------------------------------------------------
}
#endif // DLIB_FIND_MAX_PaRSE_CKY_Hh_