supervisor.cpp (10408B)
1 /** 2 * @file 3 * @brief Interface to the training and error algorithms. 4 */ 5 6 #include <iostream> 7 #include <limits> 8 #include <sstream> 9 #include <stdexcept> 10 #include <string> 11 #include <vector> 12 13 #include <boost/archive/binary_iarchive.hpp> 14 #include <boost/archive/binary_oarchive.hpp> 15 #include <boost/program_options.hpp> 16 #include <boost/shared_ptr.hpp> 17 #include <sqlite3.h> 18 19 #include "data_set.hpp" 20 #include "nmlp_iostream_exports.hpp" 21 #include "sqlite_utils.hpp" 22 #include "model.hpp" 23 24 static char const * const usage = "usage:\n" 25 "\tsupervisor --help\n" 26 "\tsupervisor ( train | test ) -d database dataset configfile kind*\n" 27 "\tsupervisor ( train | test ) --interactive -d database dataset kind*\n"; 28 29 /** @brief Fill a given variable_map with the variable in a given configuration file. */ 30 void detached_build_config(std::string const &configfile_path, boost::program_options::variables_map &vm); 31 32 /** @brief Select a set of View interactively. */ 33 std::vector<std::string> interactive_select_views(); 34 35 /** @brief Select a set of View from a given configuration. */ 36 std::vector<std::string> detached_select_views(boost::program_options::variables_map const &vm); 37 38 /** @brief Fetch the given View objects from a given database. */ 39 std::vector<boost::shared_ptr<View> > get_views(std::string const &database_path, std::vector<std::string> const &view_str); 40 41 /** @brief Train a given Model, ask for the parameters interactively. */ 42 void interactive_train(Model &model, Data_set &data_set); 43 44 /** @brief Train a given Model, get the parameters from a given configuration. */ 45 void detached_train(Model &model, Data_set &data_set, boost::program_options::variables_map const &vm); 46 47 /** @brief Test a given Model, ask for the parameters interactively. */ 48 void interactive_test(Model &model, Data_set &data_set); 49 50 /** @brief Test a given Model, get the parameters from a given configuration. */ 51 void detached_test(Model &model, Data_set &data_set, boost::program_options::variables_map const &vm); 52 53 /** @brief Update a given View database with a set of newly trained View objects. */ 54 void save_views(std::string const &database_path, std::vector<boost::shared_ptr<View> > const &views); 55 56 int main(int argc, char **argv){ 57 namespace po = boost::program_options; 58 std::string database_path, dataset_path; 59 std::vector<std::string> kinds; 60 61 po::options_description desc("Allowed options"); 62 desc.add_options() 63 ("command,a", po::value<std::string>(), "Select the operation to execute") 64 ("configfile,c", po::value<std::string>(), "Configuration file containing the training/testing parameters") 65 ("dataset,s", po::value<std::string>(&dataset_path), "Data on which the model will be trained/tested") 66 ("database,d", po::value<std::string>(&database_path), "Select the view database") 67 ("help,h", "Print a summary of the options") 68 ("interactive,i", "Start in interactive mode") 69 ("kind,k", po::value<std::vector<std::string> >(&kinds), "Set the kinds present in the dataset") 70 ; 71 72 po::positional_options_description p; 73 p.add("command", 1).add("dataset", 1).add("configfile", 1).add("kind", -1); 74 75 po::variables_map vm, vm_configfile; 76 po::store(po::command_line_parser(argc, argv).options(desc).positional(p).run(), vm); 77 po::notify(vm); 78 79 if(vm.count("help")){ 80 std::cout << usage << "\n" << desc << "\n"; 81 return 0; 82 } 83 84 if(!vm.count("database") || !vm.count("command")){ 85 std::cout << usage; 86 return 1; 87 } 88 89 if(vm.count("interactive")){ 90 if(vm.count("configfile")) 91 kinds.push_back(vm["configfile"].as<std::string>()); 92 } else { 93 if(!vm.count("configfile")){ 94 std::cout << usage; 95 return 1; 96 } 97 detached_build_config(vm["configfile"].as<std::string>(), vm_configfile); 98 } 99 100 std::vector<std::string> view_str; 101 if(vm.count("interactive")) 102 view_str=interactive_select_views(); 103 else 104 view_str=detached_select_views(vm_configfile); 105 106 std::vector<boost::shared_ptr<View> > views=get_views(database_path, view_str); 107 Model model; 108 for(std::vector<boost::shared_ptr<View> >::iterator it=views.begin(); it!=views.end(); ++it) 109 model.add_view(*it); 110 111 Data_set data_set=classification_svmfile_to_data_set(dataset_path, kinds); 112 113 if(vm["command"].as<std::string>()=="train"){ 114 if(vm.count("interactive")) 115 interactive_train(model, data_set); 116 else 117 detached_train(model, data_set, vm_configfile); 118 save_views(database_path, views); 119 } else if(vm["command"].as<std::string>()=="test"){ 120 if(vm.count("interactive")) 121 interactive_test(model, data_set); 122 else 123 detached_test(model, data_set, vm_configfile); 124 } 125 } 126 127 void detached_build_config(std::string const &configfile_path, boost::program_options::variables_map &vm){ 128 namespace po = boost::program_options; 129 130 po::options_description desc("Config file options"); 131 desc.add_options() 132 ("view", po::value<std::vector<std::string> >(), "Add a view to be trained/tested") 133 ("gradient_step", po::value<float>(), "Set the gradient step for the training algorithm") 134 ("repeat", po::value<unsigned>(), "Number of time the training algorithm will be repeated") 135 ("criterion", po::value<std::string>(), "Criterion used to compare the predicted vector to the expected vector") 136 ; 137 138 po::store(po::parse_config_file<char>(configfile_path.c_str(), desc, false), vm); 139 po::notify(vm); 140 } 141 142 std::vector<std::string> interactive_select_views(){ 143 std::vector<std::string> views; 144 std::string view; 145 while(std::cout << "Add view? ", std::getline(std::cin, view) && view!="end") 146 views.push_back(view); 147 return views; 148 } 149 150 std::vector<std::string> detached_select_views(boost::program_options::variables_map const &vm){ 151 return vm["view"].as<std::vector<std::string> >(); 152 } 153 154 std::vector<boost::shared_ptr<View> > get_views(std::string const &database_path, std::vector<std::string> const &view_str){ 155 std::vector<boost::shared_ptr<View> > views; 156 DB_connection connection(database_path); 157 sqlite3_stmt *stmt; 158 159 for(std::vector<std::string>::const_iterator view=view_str.begin(); view!=view_str.end(); ++view){ 160 std::stringstream ss; 161 ss << "SELECT data FROM views WHERE name='" << *view << "'"; 162 std::string sql=ss.str(); 163 164 int err=sqlite3_prepare_v2(connection.handle, sql.c_str(), sql.size()+1, &stmt, 0); 165 DB_statement statement(stmt); 166 if(err!=SQLITE_OK) 167 throw std::runtime_error(sqlite3_errmsg(connection.handle)); 168 169 if((err=sqlite3_step(stmt))<100) 170 throw std::runtime_error(sqlite3_errmsg(connection.handle)); 171 172 std::stringstream s; 173 char const *buf=reinterpret_cast<char const *>(sqlite3_column_blob(stmt, 0)); 174 int cb=sqlite3_column_bytes(stmt, 0); 175 s.write(buf, cb); 176 boost::archive::binary_iarchive ai(s); 177 views.push_back(boost::shared_ptr<View>()); 178 ai >> views.back(); 179 } 180 return views; 181 } 182 183 void interactive_train(Model &model, Data_set &data_set){ 184 float gradient_step; 185 std::cout << "Gradient step? "; 186 std::cin >> gradient_step; 187 188 unsigned repeat; 189 std::cout << "Repeat? "; 190 std::cin >> repeat; 191 std::cin.ignore(std::numeric_limits<std::streamsize>::max(), '\n'); 192 193 std::string criterion; 194 std::cout << "Criterion? "; 195 std::getline(std::cin, criterion); 196 197 model.stochastic_learn(data_set, repeat, gradient_step, criterion); 198 } 199 200 void detached_train(Model &model, Data_set &data_set, boost::program_options::variables_map const &vm){ 201 if(!vm.count("gradient_step")) 202 throw std::runtime_error("Missing 'gradient_step' option in configuration file"); 203 204 if(!vm.count("criterion")) 205 throw std::runtime_error("Missing 'criterion' option in configuration file"); 206 207 float gradient_step=vm["gradient_step"].as<float>(); 208 unsigned repeat=vm.count("repeat")?vm["repeat"].as<unsigned>():1; 209 std::string criterion=vm["criterion"].as<std::string>(); 210 211 model.stochastic_learn(data_set, repeat, gradient_step, criterion); 212 } 213 214 void interactive_test(Model &model, Data_set &data_set){ 215 std::string criterion; 216 std::cout << "Criterion? "; 217 std::getline(std::cin, criterion); 218 219 std::cout << "Error from view " << data_set.kind(0) << " to view " << data_set.kind(0) << ": " << model.error(data_set, criterion, 0, 0) << "\n"; 220 std::cout << "Error from view " << data_set.kind(0) << " to view " << data_set.kind(1) << ": " << model.error(data_set, criterion, 0, 1) << "\n"; 221 std::cout << "Error from view " << data_set.kind(1) << " to view " << data_set.kind(0) << ": " << model.error(data_set, criterion, 1, 0) << "\n"; 222 std::cout << "Error from view " << data_set.kind(1) << " to view " << data_set.kind(1) << ": " << model.error(data_set, criterion, 1, 1) << "\n"; 223 } 224 225 void detached_test(Model &model, Data_set &data_set, boost::program_options::variables_map const &vm){ 226 if(!vm.count("criterion")) 227 throw std::runtime_error("Missing 'criterion' option in configuration file"); 228 std::string criterion=vm["criterion"].as<std::string>(); 229 230 std::cout << "Error from view " << data_set.kind(0) << " to view " << data_set.kind(0) << ": " << model.error(data_set, criterion, 0, 0) << "\n"; 231 std::cout << "Error from view " << data_set.kind(0) << " to view " << data_set.kind(1) << ": " << model.error(data_set, criterion, 0, 1) << "\n"; 232 std::cout << "Error from view " << data_set.kind(1) << " to view " << data_set.kind(0) << ": " << model.error(data_set, criterion, 1, 0) << "\n"; 233 std::cout << "Error from view " << data_set.kind(1) << " to view " << data_set.kind(1) << ": " << model.error(data_set, criterion, 1, 1) << "\n"; 234 } 235 236 void save_views(std::string const &database_path, std::vector<boost::shared_ptr<View> > const &views){ 237 DB_connection connection(database_path); 238 sqlite3_stmt *stmt; 239 240 for(std::vector<boost::shared_ptr<View> >::const_iterator view=views.begin(); view!=views.end(); ++view){ 241 std::stringstream ss; 242 ss << "UPDATE views SET data=? WHERE name='" << (*view)->name << "'"; 243 std::string sql=ss.str(); 244 245 int err=sqlite3_prepare_v2(connection.handle, sql.c_str(), sql.size()+1, &stmt, 0); 246 DB_statement statement(stmt); 247 if(err!=SQLITE_OK) 248 throw std::runtime_error(sqlite3_errmsg(connection.handle)); 249 250 /// @todo Can it be optimised? 251 std::stringstream buf; 252 boost::archive::binary_oarchive oa(buf); 253 oa << *view; 254 std::string bufstr=buf.str(); 255 256 if(sqlite3_bind_blob(stmt, 1, bufstr.c_str(), bufstr.size(), SQLITE_STATIC)!=SQLITE_OK) 257 throw std::runtime_error(sqlite3_errmsg(connection.handle)); 258 259 if(sqlite3_step(stmt)!=SQLITE_DONE) 260 throw std::runtime_error(sqlite3_errmsg(connection.handle)); 261 } 262 } 263