commit 389d8001be77e6cacb35804236fe9d3f0930282b
parent 66d2717188e189fde5422576740903ca8e488f63
Author: Alex Auvolat <alex.auvolat@ens.fr>
Date: Mon, 6 Jul 2015 10:40:23 -0400
Blocks compatibility
Diffstat:
2 files changed, 14 insertions(+), 10 deletions(-)
diff --git a/ext_saveload.py b/ext_saveload.py
@@ -15,14 +15,14 @@ class SaveLoadParams(SimpleExtension):
def do_save(self):
with open(self.path, 'w') as f:
logger.info('Saving parameters to %s...'%self.path)
- cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
+ cPickle.dump(self.model.get_parameter_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
logger.info('Done saving.')
def do_load(self):
try:
with open(self.path, 'r') as f:
logger.info('Loading parameters from %s...'%self.path)
- self.model.set_param_values(cPickle.load(f))
+ self.model.set_parameter_values(cPickle.load(f))
except IOError:
pass
diff --git a/train.py b/train.py
@@ -21,7 +21,7 @@ blocks.config.default_seed = 123
fuel.config.default_seed = 123
try:
- from blocks.extras.extensions.plotting import Plot
+ from blocks.extras.extensions.plot import Plot
use_plot = True
except ImportError:
use_plot = False
@@ -77,10 +77,10 @@ if __name__ == "__main__":
logger.info('# Parameter shapes:')
parameters_size = 0
- for key, value in cg.get_params().iteritems():
- logger.info(' %20s %s' % (value.get_value().shape, key))
+ for value in cg.parameters:
+ logger.info(' %20s %s' % (value.get_value().shape, value.name))
parameters_size += reduce(operator.mul, value.get_value().shape, 1)
- logger.info('Total number of parameters: %d in %d matrices' % (parameters_size, len(cg.get_params())))
+ logger.info('Total number of parameters: %d in %d matrices' % (parameters_size, len(cg.parameters)))
if hasattr(config, 'step_rule'):
step_rule = config.step_rule
@@ -97,9 +97,10 @@ if __name__ == "__main__":
RemoveNotFinite(),
step_rule
]),
- params=params)
+ parameters=params)
- plot_vars = [['valid_' + x.name for x in valid_monitored]]
+ plot_vars = [['valid_' + x.name for x in valid_monitored] +
+ ['train_' + x.name for x in valid_monitored]]
logger.info('Plotted variables: %s' % str(plot_vars))
dump_path = os.path.join('model_data', model_name) + '.pkl'
@@ -110,7 +111,7 @@ if __name__ == "__main__":
prefix='valid',
every_n_batches=1000),
Printing(every_n_batches=1000),
- # FinishAfter(every_n_batches=10),
+ FinishAfter(every_n_batches=10000000),
SaveLoadParams(dump_path, cg,
before_training=True, # before training -> load params
@@ -126,7 +127,10 @@ if __name__ == "__main__":
]
if use_plot:
- extensions.append(Plot(model_name, channels=plot_vars, every_n_batches=500))
+ extensions.append(Plot(model_name,
+ channels=plot_vars,
+ every_n_batches=500,
+ server_url='http://eos6:5006/'))
main_loop = MainLoop(
model=cg,