train.py (5219B)
#!/usr/bin/env python2
"""Training entry point: build a model from a named config module and run
a Blocks MainLoop over a Fuel data stream.

Usage: train.py [--tvt | --largevalid] [--progress] config

The last argv entry names a module under the ``config`` package; that module
must expose ``Model`` and ``Stream`` classes and may optionally define
``dropout``/``dropout_inputs``, ``noise``/``noise_inputs``, ``step_rule``
and ``monitor_freq``.
"""

import importlib
import logging
import operator
import os
import sys
from functools import reduce

from theano import tensor

import blocks
import fuel

from blocks import roles
from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule, Momentum
from blocks.extensions import Printing, FinishAfter, ProgressBar
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring

# Fix the RNG seeds for both libraries so runs are reproducible.
blocks.config.default_seed = 123
fuel.config.default_seed = 123

# Live plotting lives in the optional blocks-extras package; degrade
# gracefully to no plotting when it is not installed.
try:
    from blocks.extras.extensions.plot import Plot
    use_plot = True
except ImportError:
    use_plot = False

from blocks.filter import VariableFilter
from blocks.graph import ComputationGraph, apply_dropout, apply_noise
from blocks.main_loop import MainLoop
from blocks.model import Model

# Project-local extensions: parameter checkpointing and test-set evaluation.
from ext_saveload import SaveLoadParams
from ext_test import RunOnTest

logger = logging.getLogger(__name__)

if __name__ == "__main__":
    # argv layout: program name, up to two optional flags, then the config
    # module name (always last).
    if len(sys.argv) < 2 or len(sys.argv) > 4:
        print >> sys.stderr, 'Usage: %s [--tvt | --largevalid] [--progress] config' % sys.argv[0]
        sys.exit(1)
    model_name = sys.argv[-1]
    # Relative import: loads config/<model_name>.py from the 'config' package.
    config = importlib.import_module('.%s' % model_name, 'config')

    # Log every plain-valued config attribute for the record.
    logger.info('# Configuration: %s' % config.__name__)
    for key in dir(config):
        if not key.startswith('__') and isinstance(getattr(config, key), (int, str, list, tuple)):
            logger.info(' %20s %s' % (key, str(getattr(config, key))))

    model = config.Model(config)
    model.initialize()

    # Build the data streams; req_vars are the input sources the cost brick
    # needs (NOTE(review): assumes model.cost is a Blocks application whose
    # .inputs lists source names — confirm against the config's Model).
    stream = config.Stream(config)
    inputs = stream.inputs()
    req_vars = model.cost.inputs

    train_stream = stream.train(req_vars)
    valid_stream = stream.valid(req_vars)

    # Symbolic training cost plus any auxiliary COST-role variables found in
    # its computation graph; all of them get monitored.
    cost = model.cost(**inputs)
    cg = ComputationGraph(cost)
    monitored = set([cost] + VariableFilter(roles=[roles.COST])(cg.variables))

    # A model may define a separate (e.g. noise-free) validation cost;
    # otherwise validation monitors the same variables as training.
    valid_monitored = monitored
    if hasattr(model, 'valid_cost'):
        valid_cost = model.valid_cost(**inputs)
        valid_cg = ComputationGraph(valid_cost)
        valid_monitored = set([valid_cost] + VariableFilter(roles=[roles.COST])(valid_cg.variables))

    # Optional regularization: graph rewrites applied only to the training
    # graph (the validation graph built above is untouched).
    if hasattr(config, 'dropout') and config.dropout < 1.0:
        cg = apply_dropout(cg, config.dropout_inputs(cg), config.dropout)
    if hasattr(config, 'noise') and config.noise > 0.0:
        cg = apply_noise(cg, config.noise_inputs(cg), config.noise)
    cost = cg.outputs[0]
    # NOTE: 'cg' is rebound from a ComputationGraph to a Blocks Model here;
    # everything below uses the Model interface (.parameters, etc.).
    cg = Model(cost)

    # Report per-parameter shapes and the total element count.
    logger.info('# Parameter shapes:')
    parameters_size = 0
    for value in cg.parameters:
        logger.info(' %20s %s' % (value.get_value().shape, value.name))
        parameters_size += reduce(operator.mul, value.get_value().shape, 1)
    logger.info('Total number of parameters: %d in %d matrices' % (parameters_size, len(cg.parameters)))

    # Optimizer: config may override the step rule; AdaDelta is the default.
    if hasattr(config, 'step_rule'):
        step_rule = config.step_rule
    else:
        step_rule = AdaDelta()

    logger.info("Fuel seed: %d" % fuel.config.default_seed)
    logger.info("Blocks seed: %d" % blocks.config.default_seed)

    params = cg.parameters
    # RemoveNotFinite runs first so NaN/Inf gradients are sanitized before
    # the main step rule sees them.
    algorithm = GradientDescent(
        cost=cost,
        step_rule=CompositeRule([
            RemoveNotFinite(),
            step_rule
        ]),
        parameters=params)

    # Channel names for the live plot: monitoring extensions prefix each
    # variable name with 'train_'/'valid_'.
    plot_vars = [['valid_' + x.name for x in valid_monitored] +
                 ['train_' + x.name for x in valid_monitored]]
    logger.info('Plotted variables: %s' % str(plot_vars))

    # Checkpoint file for SaveLoadParams.
    dump_path = os.path.join('model_data', model_name) + '.pkl'
    logger.info('Dump path: %s' % dump_path)

    # How often (in batches) to monitor, print, and checkpoint.
    if hasattr(config, 'monitor_freq'):
        monitor_freq = config.monitor_freq
    else:
        monitor_freq = 10000

    extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=monitor_freq),
                DataStreamMonitoring(valid_monitored, valid_stream,
                                     prefix='valid',
                                     every_n_batches=monitor_freq,
                                     after_epoch=False),
                Printing(every_n_batches=monitor_freq),
                # Effectively "train forever": stop after 10M batches.
                FinishAfter(every_n_batches=10000000),

                SaveLoadParams(dump_path, cg,
                               before_training=True,        # before training -> load params
                               every_n_batches=monitor_freq, # every N batches -> save params
                               after_epoch=False,
                               after_training=True,         # after training -> save params
                               ),

                # Periodically run the model on the test set.
                RunOnTest(model_name,
                          model,
                          stream,
                          every_n_batches=monitor_freq),
                ]

    if '--progress' in sys.argv:
        extensions.append(ProgressBar())

    if use_plot:
        # Bokeh plot server; host is hard-coded to the lab machine 'eos6'.
        extensions.append(Plot(model_name,
                               channels=plot_vars,
                               every_n_batches=500,
                               server_url='http://eos6:5006/'))

    main_loop = MainLoop(
        model=cg,
        data_stream=train_stream,
        algorithm=algorithm,
        extensions=extensions)
    main_loop.run()
    # Print the per-extension/per-batch timing profile after training ends.
    main_loop.profile.report()