// Copyright (C) 2009 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <cstdlib>
#include <ctime>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <vector>
#include <dlib/type_safe_union.h>
#include "tester.h"
namespace
{
using namespace test;
using namespace dlib;
using namespace std;
// File-level logger for this test, named "test.type_safe_union".
// NOTE(review): dlog is not referenced anywhere else in this file.
logger dlog("test.type_safe_union");
// An alternative type that cannot be copied (it derives from dlib's
// noncopyable).  It is used to check that a type_safe_union listing such a
// type can still be constructed, serialized and queried.
struct can_not_copy
    : noncopyable
{
};

// Stateless no-op overloads for can_not_copy.  can_not_copy has no data
// members, so there is nothing to write, read or exchange.  Presumably they
// exist so that type_safe_union's serialize/deserialize/swap can be
// instantiated for a union that has can_not_copy as an alternative.
void serialize(const can_not_copy& /*item*/, std::ostream& /*out*/) {}

void deserialize(can_not_copy& /*item*/, std::istream& /*in*/) {}

void swap(can_not_copy& /*lhs*/, can_not_copy& /*rhs*/) {}
// Test fixture for type_safe_union<float, double, char, std::string>.
// The fixture doubles as the visitor passed to apply_to_contents(*this): each
// operator() overload asserts that the visited value equals the value most
// recently stored in f_val / d_val / c_val / s_val, and records which
// overload ran in last_kind.  test_stuff() walks through the union's API as
// one sequential script; later checks depend on state left behind by earlier
// ones, so statement order matters.
class test
{
private:
// Identifies the visitor overload that ran most recently; NONE means no
// overload has run since last_kind was last reset.
enum kind
{
FLOAT, DOUBLE, CHAR, STRING, NONE
};
// Visitor overloads, one per alternative of tsu.
void operator() (float val)
{
DLIB_TEST(val == f_val);
last_kind = FLOAT;
}
void operator() (double val)
{
DLIB_TEST(val == d_val);
last_kind = DOUBLE;
}
void operator() (char val)
{
DLIB_TEST(val == c_val);
last_kind = CHAR;
}
// Two string overloads.  For a non-const std::string argument the
// non-const reference overload is the better match; for a const argument
// only the const overload is viable.  Presumably this covers both the
// const and the non-const form of apply_to_contents.
void operator()(std::string& val)
{
DLIB_TEST(val == s_val);
last_kind = STRING;
}
void operator()(const std::string& val)
{
DLIB_TEST(val == s_val);
last_kind = STRING;
}
// ------------------------------
// The visitor overloads above are private; presumably this friendship is
// what allows the union to invoke them while visiting.
friend class type_safe_union<float, double, char, std::string>;
typedef type_safe_union<float, double, char, std::string> tsu;
// Unions under test (default constructed, hence empty).  c is not
// referenced anywhere in this class.
tsu a, b, c;
// Expected values checked by the visitor overloads.
float f_val;
double d_val;
char c_val;
std::string s_val;
// Written by the visitor overloads; see enum kind.
kind last_kind;
public:
// Runs the whole type_safe_union check sequence.  Assumes the members a
// and b start out empty; each step relies on the state left by earlier
// steps.
void test_stuff()
{
// Type ids are the 1-based position of a type in the template argument
// list and -1 for a type that is not an alternative.  They are usable in
// constant expressions.
static_assert(tsu::get_type_id<float>() == 1, "bad type id");
static_assert(tsu::get_type_id<double>() == 2, "bad type id");
static_assert(tsu::get_type_id<char>() == 3, "bad type id");
static_assert(tsu::get_type_id<std::string>() == 4, "bad type id");
static_assert(tsu::get_type_id<long>() == -1, "This should be -1");
// A default-constructed union is empty and contains no type.
DLIB_TEST(a.is_empty() == true);
DLIB_TEST(a.contains<char>() == false);
DLIB_TEST(a.contains<float>() == false);
DLIB_TEST(a.contains<double>() == false);
DLIB_TEST(a.contains<std::string>() == false);
DLIB_TEST(a.contains<long>() == false);
// get_type_id<T>() can also be called through an instance; the union type
// itself is not one of its own alternatives.
DLIB_TEST(a.get_type_id<int>() == -1);
DLIB_TEST(a.get_type_id<float>() == 1);
DLIB_TEST(a.get_type_id<double>() == 2);
DLIB_TEST(a.get_type_id<char>() == 3);
DLIB_TEST(a.get_type_id<std::string>() == 4);
DLIB_TEST(a.get_type_id<tsu>() == -1);
// get<T>() makes the union hold a T and returns a reference to it.
// cast_to<T>() reads the value back and throws bad_type_safe_union_cast
// when the union holds a different type.
f_val = 4.345f;
a.get<float>() = f_val;
DLIB_TEST(a.cast_to<float>() == f_val);
DLIB_TEST(const_cast<const tsu&>(a).cast_to<float>() == f_val);
bool exception_thrown = false;
try {a.cast_to<char>(); }
catch (bad_type_safe_union_cast&) { exception_thrown = true;}
DLIB_TEST(exception_thrown);
DLIB_TEST(a.is_empty() == false);
DLIB_TEST(a.contains<char>() == false);
DLIB_TEST(a.contains<float>() == true);
DLIB_TEST(a.contains<double>() == false);
DLIB_TEST(a.contains<std::string>() == false);
DLIB_TEST(a.contains<long>() == false);
// Visiting a const union dispatches to the overload for the held type.
// const_cast is used here only to add const.
last_kind = NONE;
const_cast<const tsu&>(a).apply_to_contents(*this);
DLIB_TEST(last_kind == FLOAT);
// -----------
// Calling get<T>() with another T replaces the previous contents.
d_val = 4.345;
a.get<double>() = d_val;
last_kind = NONE;
a.apply_to_contents(*this);
DLIB_TEST(last_kind == DOUBLE);
// -----------
c_val = 'a';
a.get<char>() = c_val;
last_kind = NONE;
const_cast<const tsu&>(a).apply_to_contents(*this);
DLIB_TEST(last_kind == CHAR);
// -----------
s_val = "test string";
a.get<std::string>() = s_val;
last_kind = NONE;
a.apply_to_contents(*this);
DLIB_TEST(last_kind == STRING);
DLIB_TEST(a.cast_to<std::string>() == s_val);
exception_thrown = false;
try {a.cast_to<float>(); }
catch (bad_type_safe_union_cast&) { exception_thrown = true;}
DLIB_TEST(exception_thrown);
// -----------
DLIB_TEST(a.is_empty() == false);
DLIB_TEST(a.contains<char>() == false);
DLIB_TEST(a.contains<float>() == false);
DLIB_TEST(a.contains<double>() == false);
DLIB_TEST(a.contains<std::string>() == true);
DLIB_TEST(a.contains<long>() == false);
// -----------
// Member swap with the still-empty b: the string moves into b and a
// becomes empty.
a.swap(b);
DLIB_TEST(a.is_empty() == true);
DLIB_TEST(a.contains<char>() == false);
DLIB_TEST(a.contains<float>() == false);
DLIB_TEST(a.contains<double>() == false);
DLIB_TEST(a.contains<std::string>() == false);
DLIB_TEST(a.contains<long>() == false);
DLIB_TEST(b.is_empty() == false);
DLIB_TEST(b.contains<char>() == false);
DLIB_TEST(b.contains<float>() == false);
DLIB_TEST(b.contains<double>() == false);
DLIB_TEST(b.contains<std::string>() == true);
DLIB_TEST(b.contains<long>() == false);
last_kind = NONE;
b.apply_to_contents(*this);
DLIB_TEST(last_kind == STRING);
// -----------
// Swap back: a holds the string again and b is empty.
b.swap(a);
DLIB_TEST(b.is_empty() == true);
DLIB_TEST(b.contains<char>() == false);
DLIB_TEST(b.contains<float>() == false);
DLIB_TEST(b.contains<double>() == false);
DLIB_TEST(b.contains<std::string>() == false);
DLIB_TEST(b.contains<long>() == false);
DLIB_TEST(a.is_empty() == false);
DLIB_TEST(a.contains<char>() == false);
DLIB_TEST(a.contains<float>() == false);
DLIB_TEST(a.contains<double>() == false);
DLIB_TEST(a.contains<std::string>() == true);
DLIB_TEST(a.contains<long>() == false);
last_kind = NONE;
a.apply_to_contents(*this);
DLIB_TEST(last_kind == STRING);
// Visiting an empty union invokes no overload.
last_kind = NONE;
b.apply_to_contents(*this);
DLIB_TEST(last_kind == NONE);
// Non-member swap of two unions that hold the same type exchanges the
// values.
a.get<char>() = 'a';
b.get<char>() = 'b';
DLIB_TEST(a.is_empty() == false);
DLIB_TEST(a.contains<char>() == true);
DLIB_TEST(b.is_empty() == false);
DLIB_TEST(b.contains<char>() == true);
DLIB_TEST(a.contains<float>() == false);
DLIB_TEST(b.contains<float>() == false);
DLIB_TEST(a.get<char>() == 'a');
DLIB_TEST(b.get<char>() == 'b');
swap(a,b);
DLIB_TEST(a.is_empty() == false);
DLIB_TEST(a.contains<char>() == true);
DLIB_TEST(b.is_empty() == false);
DLIB_TEST(b.contains<char>() == true);
DLIB_TEST(a.contains<float>() == false);
DLIB_TEST(b.contains<float>() == false);
DLIB_TEST(a.get<char>() == 'b');
DLIB_TEST(b.get<char>() == 'a');
// -----------
// Non-member swap of unions that hold different types exchanges both the
// held type and the value.
a.get<char>() = 'a';
b.get<std::string>() = "a string";
DLIB_TEST(a.is_empty() == false);
DLIB_TEST(a.contains<char>() == true);
DLIB_TEST(b.is_empty() == false);
DLIB_TEST(b.contains<char>() == false);
DLIB_TEST(a.contains<std::string>() == false);
DLIB_TEST(b.contains<std::string>() == true);
DLIB_TEST(a.get<char>() == 'a');
DLIB_TEST(b.get<std::string>() == "a string");
swap(a,b);
DLIB_TEST(b.is_empty() == false);
DLIB_TEST(b.contains<char>() == true);
DLIB_TEST(a.is_empty() == false);
DLIB_TEST(a.contains<char>() == false);
DLIB_TEST(b.contains<std::string>() == false);
DLIB_TEST(a.contains<std::string>() == true);
DLIB_TEST(b.get<char>() == 'a');
DLIB_TEST(a.get<std::string>() == "a string");
// Serialization round trips, with contents set through get<T>().
// NOTE(review): the locals a and b shadow the class members of the same
// name; consider renaming them.
{
type_safe_union<char, float, std::string> a, b, empty_union;
ostringstream sout;
istringstream sin;
a.get<char>() = 'd';
serialize(a, sout);
sin.str(sout.str());
deserialize(b, sin);
DLIB_TEST(b.contains<int>() == false);
DLIB_TEST(b.contains<float>() == false);
DLIB_TEST(b.contains<char>() == true);
DLIB_TEST(b.get<char>() == 'd');
DLIB_TEST(a.contains<int>() == false);
DLIB_TEST(a.contains<float>() == false);
DLIB_TEST(a.contains<char>() == true);
DLIB_TEST(a.get<char>() == 'd');
// Reset both streams before the next round trip.
sin.clear();
sout.clear();
sout.str("");
a.get<std::string>() = "davis";
serialize(a, sout);
sin.str(sout.str());
deserialize(b, sin);
DLIB_TEST(b.contains<int>() == false);
DLIB_TEST(b.contains<float>() == false);
DLIB_TEST(b.contains<std::string>() == true);
DLIB_TEST(b.get<std::string>() == "davis");
sin.clear();
sout.clear();
sout.str("");
// An empty union round-trips as empty, replacing b's previous contents.
serialize(empty_union, sout);
sin.str(sout.str());
deserialize(b, sin);
DLIB_TEST(b.is_empty() == true);
}
// The same round trips, with contents set by converting assignment from a
// value of an alternative type.
{
type_safe_union<char, float, std::string> a, b, empty_union;
ostringstream sout;
istringstream sin;
a = 'd';
serialize(a, sout);
sin.str(sout.str());
deserialize(b, sin);
DLIB_TEST(b.contains<int>() == false);
DLIB_TEST(b.contains<float>() == false);
DLIB_TEST(b.contains<char>() == true);
DLIB_TEST(b.get<char>() == 'd');
DLIB_TEST(a.contains<int>() == false);
DLIB_TEST(a.contains<float>() == false);
DLIB_TEST(a.contains<char>() == true);
DLIB_TEST(a.get<char>() == 'd');
sin.clear();
sout.clear();
sout.str("");
a = std::string("davis");
serialize(a, sout);
sin.str(sout.str());
deserialize(b, sin);
DLIB_TEST(b.contains<int>() == false);
DLIB_TEST(b.contains<float>() == false);
DLIB_TEST(b.contains<std::string>() == true);
DLIB_TEST(b.get<std::string>() == "davis");
sin.clear();
sout.clear();
sout.str("");
serialize(empty_union, sout);
sin.str(sout.str());
deserialize(b, sin);
DLIB_TEST(b.is_empty() == true);
}
// Converting constructors, and a union with a non-copyable alternative
// (can_not_copy), which still serializes and supports get<T>().
{
typedef type_safe_union<char, float, std::string, can_not_copy> tsu_type;
tsu_type a('d'), aa(std::string("davis")), b, empty_union;
ostringstream sout;
istringstream sin;
serialize(a, sout);
sin.str(sout.str());
deserialize(b, sin);
DLIB_TEST(b.contains<int>() == false);
DLIB_TEST(b.contains<float>() == false);
DLIB_TEST(b.contains<char>() == true);
DLIB_TEST(b.get<char>() == 'd');
DLIB_TEST(a.contains<int>() == false);
DLIB_TEST(a.contains<float>() == false);
DLIB_TEST(a.contains<char>() == true);
DLIB_TEST(a.get<char>() == 'd');
DLIB_TEST(aa.contains<int>() == false);
DLIB_TEST(aa.contains<float>() == false);
DLIB_TEST(aa.contains<char>() == false);
DLIB_TEST(aa.contains<std::string>() == true);
sin.clear();
sout.clear();
sout.str("");
serialize(aa, sout);
sin.str(sout.str());
deserialize(b, sin);
DLIB_TEST(b.contains<int>() == false);
DLIB_TEST(b.contains<float>() == false);
DLIB_TEST(b.contains<std::string>() == true);
DLIB_TEST(b.get<std::string>() == "davis");
sin.clear();
sout.clear();
sout.str("");
serialize(empty_union, sout);
sin.str(sout.str());
deserialize(b, sin);
DLIB_TEST(b.is_empty() == true);
// Switch a (which held a char) over to the non-copyable alternative.
a.get<can_not_copy>();
DLIB_TEST(a.contains<can_not_copy>() == true);
}
// Move assignment between unions holding different alternatives.
{
type_safe_union<int,std::string> a, b;
a = std::string("asdf");
b = 3;
b = std::move(a);
DLIB_TEST(b.get<std::string>() == "asdf");
}
// Copy-initialization from a value of an alternative type.
{
type_safe_union<int,std::string> a = 3;
type_safe_union<int,std::string> b = std::string("asdf");
DLIB_TEST(a.get<int>() == 3);
DLIB_TEST(b.get<std::string>() == "asdf");
}
// A move-only alternative: moving or swapping the union transfers the
// pointer, and the moved-from union no longer holds a ptr_t.
{
using ptr_t = std::unique_ptr<std::string>;
type_safe_union<int, ptr_t> a;
type_safe_union<int, ptr_t> b = ptr_t(new std::string("asdf"));
a = std::move(b);
DLIB_TEST(a.contains<ptr_t>());
DLIB_TEST(!b.contains<ptr_t>());
DLIB_TEST(*a.get<ptr_t>() == "asdf");
swap(a,b);
DLIB_TEST(b.contains<ptr_t>());
DLIB_TEST(!a.contains<ptr_t>());
DLIB_TEST(*b.get<ptr_t>() == "asdf");
}
{
//testing copy semantics and move semantics
// mytype sets i to 0 in the source when it is moved from, so checking i
// afterwards shows whether a copy or a move took place.
struct mytype
{
mytype(int i_ = 0) : i(i_) {}
mytype(const mytype& other) : i(other.i) {}
mytype& operator=(const mytype& other) {i = other.i; return *this;}
mytype(mytype&& other) : i(other.i) {other.i = 0;}
mytype& operator=(mytype&& other) {i = other.i ; other.i = 0; return *this;}
int i = 0;
};
// NOTE(review): this alias shadows the class's tsu typedef within this
// scope.
using tsu = type_safe_union<int,mytype>;
{
mytype a(10);
tsu ta(a); //copy constructor
DLIB_TEST(a.i == 10);
DLIB_TEST(ta.cast_to<mytype>().i == 10);
}
{
mytype a(10);
tsu ta;
ta = a; //copy assign
DLIB_TEST(a.i == 10);
DLIB_TEST(ta.cast_to<mytype>().i == 10);
}
{
mytype a(10);
tsu ta(std::move(a)); //move constructor
DLIB_TEST(a.i == 0);
DLIB_TEST(ta.cast_to<mytype>().i == 10);
}
{
mytype a(10);
tsu ta;
ta = std::move(a); //move assign
DLIB_TEST(a.i == 0);
DLIB_TEST(ta.cast_to<mytype>().i == 10);
}
{
tsu ta(mytype(10));
DLIB_TEST(ta.cast_to<mytype>().i == 10);
tsu tb(ta); //copy constructor
DLIB_TEST(ta.cast_to<mytype>().i == 10);
DLIB_TEST(tb.cast_to<mytype>().i == 10);
}
{
tsu ta(mytype(10));
DLIB_TEST(ta.cast_to<mytype>().i == 10);
tsu tb;
tb = ta; //copy assign
DLIB_TEST(ta.cast_to<mytype>().i == 10);
DLIB_TEST(tb.cast_to<mytype>().i == 10);
}
// Moving a whole union leaves the source union empty.
{
tsu ta(mytype(10));
DLIB_TEST(ta.cast_to<mytype>().i == 10);
tsu tb(std::move(ta)); //move constructor
DLIB_TEST(ta.is_empty());
DLIB_TEST(tb.cast_to<mytype>().i == 10);
}
{
tsu ta(mytype(10));
DLIB_TEST(ta.cast_to<mytype>().i == 10);
tsu tb;
tb = std::move(ta); //move assign
DLIB_TEST(ta.is_empty());
DLIB_TEST(tb.cast_to<mytype>().i == 10);
}
}
{
//testing emplace(), copy semantics, move semantics, swap, overloaded, and new visitor
type_safe_union<int, float, std::string> a, b;
a.emplace<std::string>("hello world");
DLIB_TEST(a.contains<std::string>());
b = a; //copy
DLIB_TEST(a.contains<std::string>());
DLIB_TEST(b.contains<std::string>());
DLIB_TEST(a.cast_to<std::string>() == "hello world");
DLIB_TEST(b.cast_to<std::string>() == "hello world");
a = 1;
DLIB_TEST(a.contains<int>());
DLIB_TEST(a.cast_to<int>() == 1);
b = std::move(a);
DLIB_TEST(b.contains<int>());
DLIB_TEST(b.cast_to<int>() == 1);
// An empty union reports current type id 0.
DLIB_TEST(a.is_empty());
DLIB_TEST(a.get_current_type_id() == 0);
swap(a, b);
DLIB_TEST(a.contains<int>());
DLIB_TEST(a.cast_to<int>() == 1);
DLIB_TEST(b.is_empty());
DLIB_TEST(b.get_current_type_id() == 0);
//visit can return non-void types
auto ret = visit(overloaded(
[](int) {
return std::string("int");
},
[](float) {
return std::string("float");
},
[](const std::string&) {
return std::string("std::string");
}
), a);
static_assert(std::is_same<std::string, decltype(ret)>::value, "bad return type");
DLIB_TEST(ret == "int");
//apply_to_contents can only return void
a = std::string("hello there!");
std::string str;
a.apply_to_contents(overloaded(
[&str](int) {
str = std::string("int");
},
[&str](float) {
str = std::string("float");
},
[&str](const std::string& item) {
str = item;
}
));
DLIB_TEST(str == "hello there!");
}
{
//nested unions
using tsu_a = type_safe_union<int,float,std::string>;
using tsu_b = type_safe_union<int,float,std::string,tsu_a>;
// in_place_tag<tsu_a> constructs the nested tsu_a alternative directly
// from the remaining argument.
tsu_b object(dlib::in_place_tag<tsu_a>{}, std::string("hello from bottom node"));
DLIB_TEST(object.contains<tsu_a>());
DLIB_TEST(object.get<tsu_a>().get<std::string>() == "hello from bottom node");
// The outer visitor recurses into the nested union with a second visit().
auto ret = visit(overloaded(
[](int) {
return std::string("int");
},
[](float) {
return std::string("float");
},
[](std::string) {
return std::string("std::string");
},
[](const tsu_a& item) {
return visit( overloaded(
[](int) {
return std::string("nested int");
},
[](float) {
return std::string("nested float");
},
[](std::string str) {
return str;
}
), item);
}
), object);
static_assert(std::is_same<std::string, decltype(ret)>::value, "bad type");
DLIB_TEST(ret == "hello from bottom node");
}
{
//struct visitor
// NOTE(review): this alias shadows the class's tsu typedef within this
// scope.
using tsu = type_safe_union<int,float,std::string>;
struct visitor_private
{
std::string operator()(int)
{
return std::string("int");
}
std::string operator()(float)
{
return std::string("float");
}
std::string operator()(const std::string& str)
{
return str;
}
};
visitor_private visitor;
tsu a = std::string("hello from private visitor");
auto ret = visit(visitor, a);
static_assert(std::is_same<std::string, decltype(ret)>::value, "bad type");
DLIB_TEST(ret == "hello from private visitor");
}
}
};
namespace test_for_each_1
{
/*! Local classes aren't allowed to have template member functions... !*/
using tsu = type_safe_union<int,float,std::string>;
static_assert(type_safe_union_size<tsu>::value == 3, "bad number of types");
static_assert(std::is_same<type_safe_union_alternative_t<0, tsu>, int>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<1, tsu>, float>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<2, tsu>, std::string>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<0, const tsu>, const int>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<1, const tsu>, const float>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<2, const tsu>, const std::string>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<0, volatile tsu>, volatile int>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<1, volatile tsu>, volatile float>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<2, volatile tsu>, volatile std::string>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<0, const volatile tsu>, const volatile int>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<1, const volatile tsu>, const volatile float>::value, "bad type");
static_assert(std::is_same<type_safe_union_alternative_t<2, const volatile tsu>, const volatile std::string>::value, "bad type");
struct for_each_visitor
{
std::vector<int> type_indices;
template<typename T>
void operator()(dlib::in_place_tag<T>, const tsu& item)
{
type_indices.push_back(item.get_type_id<T>());
}
};
void test()
{
tsu a;
for_each_visitor visitor;
for_each_type(visitor, a);
DLIB_TEST(visitor.type_indices.size() == 3);
DLIB_TEST(visitor.type_indices[0] == 1);
DLIB_TEST(visitor.type_indices[1] == 2);
DLIB_TEST(visitor.type_indices[2] == 3);
}
}
namespace test_for_each_2
{
/*! Local classes aren't allowed to have template member functions... !*/
using tsu = type_safe_union<int,float,std::string>;
//for_each() that demonstrates an actual use-case.
//Instead of something simple like a target index, you might want to set the variant
//based on some specific state that is unique to one of the alternative types.
//For example, you might want to conditionally set the variant based on hashes.
struct for_each_visitor
{
for_each_visitor(int target_index_) : target_index(target_index_) {}
template<typename TagType>
void operator()(TagType tag, tsu& item)
{
if (item.get_type_id<TagType>() == target_index)
item = tsu{tag};
}
const int target_index = 0;
};
void test()
{
tsu a;
for_each_type(for_each_visitor{1}, a);
DLIB_TEST(a.contains<int>());
a.clear();
for_each_type(for_each_visitor{2}, a);
DLIB_TEST(a.contains<float>());
a.clear();
for_each_type(for_each_visitor{3}, a);
DLIB_TEST(a.contains<std::string>());
a.clear();
for_each_type(for_each_visitor{-1}, a);
DLIB_TEST(a.is_empty());
a.clear();
for_each_type(for_each_visitor{215465}, a);
DLIB_TEST(a.is_empty());
}
}
namespace test_for_each_3
{
using tsu1 = type_safe_union<int,float,std::string>;
using tsu2 = type_safe_union<std::string,long,char>;
struct serializer_typeid
{
serializer_typeid(std::ostream& out_) : out(out_) {}
template<typename T>
void operator()(const T& x)
{
dlib::serialize(typeid(T).hash_code(), out);
dlib::serialize(x, out);
}
std::ostream& out;
};
struct deserializer_typeid
{
deserializer_typeid(std::istream& in_) : in(in_)
{
dlib::deserialize(hash_code, in);
}
template<typename T, typename TSU>
void operator()(in_place_tag<T>, TSU&& me)
{
if (typeid(T).hash_code() == hash_code)
dlib::deserialize(me.template get<T>(), in);
}
std::size_t hash_code = 0;
std::istream& in;
};
void test()
{
tsu1 a;
a.get<std::string>() = "hello from tsu1";
std::stringstream out;
visit(serializer_typeid(out), a);
tsu2 b;
for_each_type(deserializer_typeid(out), b);
DLIB_TEST(b.contains<std::string>());
DLIB_TEST(b.get<std::string>() == "hello from tsu1");
}
}
class type_safe_union_tester : public tester
{
public:
type_safe_union_tester (
) :
tester ("test_type_safe_union",
"Runs tests on the type_safe_union object")
{}
void perform_test (
)
{
for (int i = 0; i < 10; ++i)
{
test a;
a.test_stuff();
test_for_each_1::test();
test_for_each_2::test();
test_for_each_3::test();
}
}
} a;
}