taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

train.py (5219B)


      1 #!/usr/bin/env python2
      2 
      3 import importlib
      4 import logging
      5 import operator
      6 import os
      7 import sys
      8 from functools import reduce
      9 
     10 from theano import tensor
     11 
     12 import blocks
     13 import fuel
     14 
     15 from blocks import roles
     16 from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule, Momentum
     17 from blocks.extensions import Printing, FinishAfter, ProgressBar
     18 from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
     19 
     20 blocks.config.default_seed = 123
     21 fuel.config.default_seed = 123
     22 
     23 try:
     24     from blocks.extras.extensions.plot import Plot
     25     use_plot = True
     26 except ImportError:
     27     use_plot = False
     28     
     29 from blocks.filter import VariableFilter
     30 from blocks.graph import ComputationGraph, apply_dropout, apply_noise
     31 from blocks.main_loop import MainLoop
     32 from blocks.model import Model
     33 
     34 from ext_saveload import SaveLoadParams
     35 from ext_test import RunOnTest
     36 
     37 logger = logging.getLogger(__name__)
     38 
     39 if __name__ == "__main__":
     40     if len(sys.argv) < 2 or len(sys.argv) > 4:
     41         print >> sys.stderr, 'Usage: %s [--tvt | --largevalid] [--progress] config' % sys.argv[0]
     42         sys.exit(1)
     43     model_name = sys.argv[-1]
     44     config = importlib.import_module('.%s' % model_name, 'config')
     45 
     46     logger.info('# Configuration: %s' % config.__name__)
     47     for key in dir(config):
     48         if not key.startswith('__') and isinstance(getattr(config, key), (int, str, list, tuple)):
     49             logger.info('    %20s %s' % (key, str(getattr(config, key))))
     50 
     51     model = config.Model(config)
     52     model.initialize()
     53 
     54     stream = config.Stream(config)
     55     inputs = stream.inputs()
     56     req_vars = model.cost.inputs
     57 
     58     train_stream = stream.train(req_vars)
     59     valid_stream = stream.valid(req_vars)
     60 
     61     cost = model.cost(**inputs)
     62     cg = ComputationGraph(cost)
     63     monitored = set([cost] + VariableFilter(roles=[roles.COST])(cg.variables))
     64 
     65     valid_monitored = monitored
     66     if hasattr(model, 'valid_cost'):
     67         valid_cost = model.valid_cost(**inputs)
     68         valid_cg = ComputationGraph(valid_cost)
     69         valid_monitored = set([valid_cost] + VariableFilter(roles=[roles.COST])(valid_cg.variables))
     70 
     71     if hasattr(config, 'dropout') and config.dropout < 1.0:
     72         cg = apply_dropout(cg, config.dropout_inputs(cg), config.dropout)
     73     if hasattr(config, 'noise') and config.noise > 0.0:
     74         cg = apply_noise(cg, config.noise_inputs(cg), config.noise)
     75     cost = cg.outputs[0]
     76     cg = Model(cost)
     77 
     78     logger.info('# Parameter shapes:')
     79     parameters_size = 0
     80     for value in cg.parameters:
     81         logger.info('    %20s %s' % (value.get_value().shape, value.name))
     82         parameters_size += reduce(operator.mul, value.get_value().shape, 1)
     83     logger.info('Total number of parameters: %d in %d matrices' % (parameters_size, len(cg.parameters)))
     84 
     85     if hasattr(config, 'step_rule'):
     86         step_rule = config.step_rule
     87     else:
     88         step_rule = AdaDelta()
     89 
     90     logger.info("Fuel seed: %d" % fuel.config.default_seed)
     91     logger.info("Blocks seed: %d" % blocks.config.default_seed)
     92 
     93     params = cg.parameters
     94     algorithm = GradientDescent(
     95         cost=cost,
     96         step_rule=CompositeRule([
     97                 RemoveNotFinite(),
     98                 step_rule
     99             ]),
    100         parameters=params)
    101     
    102     plot_vars = [['valid_' + x.name for x in valid_monitored] +
    103                  ['train_' + x.name for x in valid_monitored]]
    104     logger.info('Plotted variables: %s' % str(plot_vars))
    105 
    106     dump_path = os.path.join('model_data', model_name) + '.pkl'
    107     logger.info('Dump path: %s' % dump_path)
    108 
    109     if hasattr(config, 'monitor_freq'):
    110         monitor_freq = config.monitor_freq
    111     else:
    112         monitor_freq = 10000
    113 
    114     extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=monitor_freq),
    115                 DataStreamMonitoring(valid_monitored, valid_stream,
    116                                      prefix='valid',
    117                                      every_n_batches=monitor_freq,
    118                                      after_epoch=False),
    119                 Printing(every_n_batches=monitor_freq),
    120                 FinishAfter(every_n_batches=10000000),
    121 
    122                 SaveLoadParams(dump_path, cg,
    123                                before_training=True,        # before training -> load params
    124                                every_n_batches=monitor_freq,# every N batches -> save params
    125                                after_epoch=False,
    126                                after_training=True,         # after training -> save params
    127                                ),
    128 
    129                 RunOnTest(model_name,
    130                           model,
    131                           stream,
    132                           every_n_batches=monitor_freq),
    133                 ]
    134 
    135     if '--progress' in sys.argv:
    136         extensions.append(ProgressBar())
    137     
    138     if use_plot:
    139         extensions.append(Plot(model_name,
    140                                channels=plot_vars,
    141                                every_n_batches=500,
    142                                server_url='http://eos6:5006/'))
    143 
    144     main_loop = MainLoop(
    145         model=cg,
    146         data_stream=train_stream,
    147         algorithm=algorithm,
    148         extensions=extensions)
    149     main_loop.run()
    150     main_loop.profile.report()