taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

ext_saveload.py (1000B)


import cPickle
import logging
import os

from blocks.extensions import SimpleExtension
      5 
      6 logger = logging.getLogger(__name__)
      7 
      8 class SaveLoadParams(SimpleExtension):
      9     def __init__(self, path, model, **kwargs):
     10         super(SaveLoadParams, self).__init__(**kwargs)
     11 
     12         self.path = path
     13         self.model = model
     14     
     15     def do_save(self):
     16         with open(self.path, 'w') as f:
     17             logger.info('Saving parameters to %s...'%self.path)
     18             cPickle.dump(self.model.get_parameter_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
     19             logger.info('Done saving.')
     20     
     21     def do_load(self):
     22         try:
     23             with open(self.path, 'r') as f:
     24                 logger.info('Loading parameters from %s...'%self.path)
     25                 self.model.set_parameter_values(cPickle.load(f))
     26         except IOError:
     27             pass
     28 
     29     def do(self, which_callback, *args):
     30         if which_callback == 'before_training':
     31             self.do_load()
     32         else:
     33             self.do_save()