ext_saveload.py (1000B)
import errno
import logging

try:
    import cPickle
except ImportError:  # Python 3: pickle uses its C accelerator automatically
    import pickle as cPickle

from blocks.extensions import SimpleExtension

logger = logging.getLogger(__name__)


class SaveLoadParams(SimpleExtension):
    """Load model parameters before training and save them on every other callback.

    On the ``'before_training'`` callback the parameters are restored from
    ``path`` (if that file can be opened); on any other callback this
    extension fires for, the current parameters are pickled to ``path``.

    Args:
        path: File the parameters are pickled to / unpickled from.
        model: Object providing ``get_parameter_values()`` and
            ``set_parameter_values(values)``.
        **kwargs: Forwarded unchanged to ``SimpleExtension``
            (presumably the callback triggers such as ``before_training=True``
            or ``after_epoch=True`` -- confirm against the Blocks docs).
    """

    def __init__(self, path, model, **kwargs):
        super(SaveLoadParams, self).__init__(**kwargs)

        self.path = path
        self.model = model

    def do_save(self):
        """Pickle ``model.get_parameter_values()`` to ``self.path``."""
        # Binary mode is required: HIGHEST_PROTOCOL is a binary pickle format,
        # and text mode corrupts it on platforms that translate newlines.
        with open(self.path, 'wb') as f:
            logger.info('Saving parameters to %s...', self.path)
            cPickle.dump(self.model.get_parameter_values(), f,
                         protocol=cPickle.HIGHEST_PROTOCOL)
        logger.info('Done saving.')

    def do_load(self):
        """Restore parameters from ``self.path`` if the file can be opened.

        A file that cannot be opened (e.g. it does not exist yet on a first
        run) is skipped with a log message rather than an exception. Errors
        raised while unpickling or applying the parameters propagate, so a
        damaged checkpoint is not silently ignored and later overwritten.
        """
        try:
            f = open(self.path, 'rb')
        except IOError as err:
            if err.errno == errno.ENOENT:
                logger.info('No saved parameters at %s; starting fresh.',
                            self.path)
            else:
                logger.warning('Could not open %s to load parameters: %s',
                               self.path, err)
            return
        with f:
            logger.info('Loading parameters from %s...', self.path)
            self.model.set_parameter_values(cPickle.load(f))

    def do(self, which_callback, *args):
        """Load on ``'before_training'``; save on every other callback."""
        if which_callback == 'before_training':
            self.do_load()
        else:
            self.do_save()