mi024

College project "Projet IAD" master 1
git clone https://esimon.eu/repos/mi024.git
Log | Files | Refs | README

supervisor.cpp (10408B)


      1 /**
      2  * @file
      3  * @brief Interface to the training and error algorithms.
      4  */
      5 
      6 #include <iostream>
      7 #include <limits>
      8 #include <sstream>
      9 #include <stdexcept>
     10 #include <string>
     11 #include <vector>
     12 
     13 #include <boost/archive/binary_iarchive.hpp>
     14 #include <boost/archive/binary_oarchive.hpp>
     15 #include <boost/program_options.hpp>
     16 #include <boost/shared_ptr.hpp>
     17 #include <sqlite3.h>
     18 
     19 #include "data_set.hpp"
     20 #include "nmlp_iostream_exports.hpp"
     21 #include "sqlite_utils.hpp"
     22 #include "model.hpp"
     23 
/** @brief Command-line synopsis printed for --help and when required arguments are missing. */
static char const * const usage = "usage:\n"
                                  "\tsupervisor --help\n"
                                  "\tsupervisor ( train | test ) -d database dataset configfile kind*\n"
                                  "\tsupervisor ( train | test ) --interactive -d database dataset kind*\n";

/**
 * @brief Fill a given variables_map with the options found in a given configuration file.
 * @param configfile_path Path of the configuration file to parse.
 * @param vm Map receiving the parsed options (view, gradient_step, repeat, criterion).
 */
void detached_build_config(std::string const &configfile_path, boost::program_options::variables_map &vm);

/**
 * @brief Select a set of View interactively.
 * @return The view names read line by line from standard input, up to the line "end" or the end of input.
 */
std::vector<std::string> interactive_select_views();

/**
 * @brief Select a set of View from a given configuration.
 * @param vm Options parsed from the configuration file.
 * @return The values of the "view" option.
 */
std::vector<std::string> detached_select_views(boost::program_options::variables_map const &vm);

/**
 * @brief Fetch the given View objects from a given database.
 * @param database_path Path used to open the database connection.
 * @param view_str Names of the views to load from the "views" table.
 * @return One deserialized View per requested name, in the same order.
 * @throw std::runtime_error if a SQLite call reports an error.
 */
std::vector<boost::shared_ptr<View> > get_views(std::string const &database_path, std::vector<std::string> const &view_str);

/**
 * @brief Train a given Model, ask for the parameters interactively.
 *
 * The gradient step, the repeat count and the criterion are read from standard input
 * and forwarded to Model::stochastic_learn.
 */
void interactive_train(Model &model, Data_set &data_set);

/**
 * @brief Train a given Model, get the parameters from a given configuration.
 * @throw std::runtime_error if the "gradient_step" or "criterion" option is missing.
 */
void detached_train(Model &model, Data_set &data_set, boost::program_options::variables_map const &vm);

/**
 * @brief Test a given Model, ask for the parameters interactively.
 *
 * The criterion is read from standard input; the errors between kinds 0 and 1 are printed on standard output.
 */
void interactive_test(Model &model, Data_set &data_set);

/**
 * @brief Test a given Model, get the parameters from a given configuration.
 * @throw std::runtime_error if the "criterion" option is missing.
 */
void detached_test(Model &model, Data_set &data_set, boost::program_options::variables_map const &vm);

/**
 * @brief Update a given View database with a set of newly trained View objects.
 * @param database_path Path used to open the database connection.
 * @param views Views to serialize into the "data" column of the row carrying their name.
 * @throw std::runtime_error if a SQLite call reports an error.
 */
void save_views(std::string const &database_path, std::vector<boost::shared_ptr<View> > const &views);
     55 
     56 int main(int argc, char **argv){
     57 	namespace po = boost::program_options;
     58 	std::string database_path, dataset_path;
     59 	std::vector<std::string> kinds;
     60 
     61 	po::options_description desc("Allowed options");
     62 	desc.add_options()
     63 		("command,a", po::value<std::string>(), "Select the operation to execute")
     64 		("configfile,c", po::value<std::string>(), "Configuration file containing the training/testing parameters")
     65 		("dataset,s", po::value<std::string>(&dataset_path), "Data on which the model will be trained/tested")
     66 		("database,d", po::value<std::string>(&database_path), "Select the view database")
     67 		("help,h", "Print a summary of the options")
     68 		("interactive,i", "Start in interactive mode")
     69 		("kind,k", po::value<std::vector<std::string> >(&kinds), "Set the kinds present in the dataset")
     70 		;
     71 
     72 	po::positional_options_description p;
     73 	p.add("command", 1).add("dataset", 1).add("configfile", 1).add("kind", -1);
     74 
     75 	po::variables_map vm, vm_configfile;
     76 	po::store(po::command_line_parser(argc, argv).options(desc).positional(p).run(), vm);
     77 	po::notify(vm);
     78 
     79 	if(vm.count("help")){
     80 		std::cout << usage << "\n" << desc << "\n";
     81 		return 0;
     82 	}
     83 
     84 	if(!vm.count("database") || !vm.count("command")){
     85 		std::cout << usage;
     86 		return 1;
     87 	}
     88 
     89 	if(vm.count("interactive")){
     90 		if(vm.count("configfile"))
     91 			kinds.push_back(vm["configfile"].as<std::string>());
     92 	} else {
     93 		if(!vm.count("configfile")){
     94 			std::cout << usage;
     95 			return 1;
     96 		}
     97 		detached_build_config(vm["configfile"].as<std::string>(), vm_configfile);
     98 	}
     99 
    100 	std::vector<std::string> view_str;
    101 	if(vm.count("interactive"))
    102 		view_str=interactive_select_views();
    103 	else
    104 		view_str=detached_select_views(vm_configfile);
    105 
    106 	std::vector<boost::shared_ptr<View> > views=get_views(database_path, view_str);
    107 	Model model;
    108 	for(std::vector<boost::shared_ptr<View> >::iterator it=views.begin(); it!=views.end(); ++it)
    109 		model.add_view(*it);
    110 
    111 	Data_set data_set=classification_svmfile_to_data_set(dataset_path, kinds);
    112 
    113 	if(vm["command"].as<std::string>()=="train"){
    114 		if(vm.count("interactive"))
    115 			interactive_train(model, data_set);
    116 		else
    117 			detached_train(model, data_set, vm_configfile);
    118 		save_views(database_path, views);
    119 	} else if(vm["command"].as<std::string>()=="test"){
    120 		if(vm.count("interactive"))
    121 			interactive_test(model, data_set);
    122 		else
    123 			detached_test(model, data_set, vm_configfile);
    124 	}
    125 }
    126 
    127 void detached_build_config(std::string const &configfile_path, boost::program_options::variables_map &vm){
    128 	namespace po = boost::program_options;
    129 
    130 	po::options_description desc("Config file options");
    131 	desc.add_options()
    132 		("view", po::value<std::vector<std::string> >(), "Add a view to be trained/tested")
    133 		("gradient_step", po::value<float>(), "Set the gradient step for the training algorithm")
    134 		("repeat", po::value<unsigned>(), "Number of time the training algorithm will be repeated")
    135 		("criterion", po::value<std::string>(), "Criterion used to compare the predicted vector to the expected vector")
    136 		;
    137 
    138 	po::store(po::parse_config_file<char>(configfile_path.c_str(), desc, false), vm);
    139 	po::notify(vm);
    140 }
    141 
    142 std::vector<std::string> interactive_select_views(){
    143 	std::vector<std::string> views;
    144 	std::string view;
    145 	while(std::cout << "Add view? ", std::getline(std::cin, view) && view!="end")
    146 		views.push_back(view);
    147 	return views;
    148 }
    149 
    150 std::vector<std::string> detached_select_views(boost::program_options::variables_map const &vm){
    151 	return vm["view"].as<std::vector<std::string> >();
    152 }
    153 
    154 std::vector<boost::shared_ptr<View> > get_views(std::string const &database_path, std::vector<std::string> const &view_str){
    155 	std::vector<boost::shared_ptr<View> > views;
    156 	DB_connection connection(database_path);
    157 	sqlite3_stmt *stmt;
    158 
    159 	for(std::vector<std::string>::const_iterator view=view_str.begin(); view!=view_str.end(); ++view){
    160 		std::stringstream ss;
    161 		ss << "SELECT data FROM views WHERE name='" << *view << "'";
    162 		std::string sql=ss.str();
    163 
    164 		int err=sqlite3_prepare_v2(connection.handle, sql.c_str(), sql.size()+1, &stmt, 0);
    165 		DB_statement statement(stmt);
    166 		if(err!=SQLITE_OK)
    167 			throw std::runtime_error(sqlite3_errmsg(connection.handle));
    168 
    169 		if((err=sqlite3_step(stmt))<100)
    170 			throw std::runtime_error(sqlite3_errmsg(connection.handle));
    171 
    172 		std::stringstream s;
    173 		char const *buf=reinterpret_cast<char const *>(sqlite3_column_blob(stmt, 0));
    174 		int cb=sqlite3_column_bytes(stmt, 0);
    175 		s.write(buf, cb);
    176 		boost::archive::binary_iarchive ai(s);
    177 		views.push_back(boost::shared_ptr<View>());
    178 		ai >> views.back();
    179 	}
    180 	return views;
    181 }
    182 
    183 void interactive_train(Model &model, Data_set &data_set){
    184 	float gradient_step;
    185 	std::cout << "Gradient step? ";
    186 	std::cin >> gradient_step;
    187 
    188 	unsigned repeat;
    189 	std::cout << "Repeat? ";
    190 	std::cin >> repeat;
    191 	std::cin.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
    192 
    193 	std::string criterion;
    194 	std::cout << "Criterion? ";
    195 	std::getline(std::cin, criterion);
    196 
    197 	model.stochastic_learn(data_set, repeat, gradient_step, criterion);
    198 }
    199 
    200 void detached_train(Model &model, Data_set &data_set, boost::program_options::variables_map const &vm){
    201 	if(!vm.count("gradient_step"))
    202 		throw std::runtime_error("Missing 'gradient_step' option in configuration file");
    203 
    204 	if(!vm.count("criterion"))
    205 		throw std::runtime_error("Missing 'criterion' option in configuration file");
    206 
    207 	float gradient_step=vm["gradient_step"].as<float>();
    208 	unsigned repeat=vm.count("repeat")?vm["repeat"].as<unsigned>():1;
    209 	std::string criterion=vm["criterion"].as<std::string>();
    210 
    211 	model.stochastic_learn(data_set, repeat, gradient_step, criterion);
    212 }
    213 
    214 void interactive_test(Model &model, Data_set &data_set){
    215 	std::string criterion;
    216 	std::cout << "Criterion? ";
    217 	std::getline(std::cin, criterion);
    218 
    219 	std::cout << "Error from view " << data_set.kind(0) << " to view " << data_set.kind(0) << ": " << model.error(data_set, criterion, 0, 0) << "\n";
    220 	std::cout << "Error from view " << data_set.kind(0) << " to view " << data_set.kind(1) << ": " << model.error(data_set, criterion, 0, 1) << "\n";
    221 	std::cout << "Error from view " << data_set.kind(1) << " to view " << data_set.kind(0) << ": " << model.error(data_set, criterion, 1, 0) << "\n";
    222 	std::cout << "Error from view " << data_set.kind(1) << " to view " << data_set.kind(1) << ": " << model.error(data_set, criterion, 1, 1) << "\n";
    223 }
    224 
    225 void detached_test(Model &model, Data_set &data_set, boost::program_options::variables_map const &vm){
    226 	if(!vm.count("criterion"))
    227 		throw std::runtime_error("Missing 'criterion' option in configuration file");
    228 	std::string criterion=vm["criterion"].as<std::string>();
    229 
    230 	std::cout << "Error from view " << data_set.kind(0) << " to view " << data_set.kind(0) << ": " << model.error(data_set, criterion, 0, 0) << "\n";
    231 	std::cout << "Error from view " << data_set.kind(0) << " to view " << data_set.kind(1) << ": " << model.error(data_set, criterion, 0, 1) << "\n";
    232 	std::cout << "Error from view " << data_set.kind(1) << " to view " << data_set.kind(0) << ": " << model.error(data_set, criterion, 1, 0) << "\n";
    233 	std::cout << "Error from view " << data_set.kind(1) << " to view " << data_set.kind(1) << ": " << model.error(data_set, criterion, 1, 1) << "\n";
    234 }
    235 
    236 void save_views(std::string const &database_path, std::vector<boost::shared_ptr<View> > const &views){
    237 	DB_connection connection(database_path);
    238 	sqlite3_stmt *stmt;
    239 
    240 	for(std::vector<boost::shared_ptr<View> >::const_iterator view=views.begin(); view!=views.end(); ++view){
    241 		std::stringstream ss;
    242 		ss << "UPDATE views SET data=? WHERE name='" << (*view)->name << "'";
    243 		std::string sql=ss.str();
    244 
    245 		int err=sqlite3_prepare_v2(connection.handle, sql.c_str(), sql.size()+1, &stmt, 0);
    246 		DB_statement statement(stmt);
    247 		if(err!=SQLITE_OK)
    248 			throw std::runtime_error(sqlite3_errmsg(connection.handle));
    249 
    250 		/// @todo Can it be optimised?
    251 		std::stringstream buf;
    252 		boost::archive::binary_oarchive oa(buf);
    253 		oa << *view;
    254 		std::string bufstr=buf.str();
    255 
    256 		if(sqlite3_bind_blob(stmt, 1, bufstr.c_str(), bufstr.size(), SQLITE_STATIC)!=SQLITE_OK)
    257 			throw std::runtime_error(sqlite3_errmsg(connection.handle));
    258 
    259 		if(sqlite3_step(stmt)!=SQLITE_DONE)
    260 			throw std::runtime_error(sqlite3_errmsg(connection.handle));
    261 	}
    262 }
    263