taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

mlp.py (5804B)


      1 from theano import tensor
      2 import numpy
      3 
      4 import fuel
      5 import blocks
      6 
      7 from fuel.transformers import Batch, MultiProcessing, Mapping, SortMapping, Unpack
      8 from fuel.streams import DataStream
      9 from fuel.schemes import ConstantScheme, ShuffledExampleScheme
     10 from blocks.bricks import application, MLP, Rectifier, Initializable
     11 
     12 import data
     13 from data import transformers
     14 from data.hdf5 import TaxiDataset, TaxiStream
     15 from data.cut import TaxiTimeCutScheme
     16 from model import ContextEmbedder
     17 
# Set the default random seed of both Blocks and Fuel to the same fixed constant.
blocks.config.default_seed = 123
fuel.config.default_seed = 123
     20 
     21 class FFMLP(Initializable):
     22     def __init__(self, config, output_layer=None, **kwargs):
     23         super(FFMLP, self).__init__(**kwargs)
     24         self.config = config
     25 
     26         self.context_embedder = ContextEmbedder(config)
     27 
     28         output_activation = [] if output_layer is None else [output_layer()]
     29         output_dim = [] if output_layer is None else [config.dim_output]
     30         self.mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + output_activation,
     31                        dims=[config.dim_input] + config.dim_hidden + output_dim)
     32 
     33         self.extremities = {'%s_k_%s' % (side, ['latitude', 'longitude'][axis]): axis for side in ['first', 'last'] for axis in [0, 1]}
     34         self.inputs = self.context_embedder.inputs + self.extremities.keys()
     35         self.children = [ self.context_embedder, self.mlp ]
     36 
     37     def _push_initialization_config(self):
     38         self.mlp.weights_init = self.config.mlp_weights_init
     39         self.mlp.biases_init = self.config.mlp_biases_init
     40 
     41     @application(outputs=['prediction'])
     42     def predict(self, **kwargs):
     43         embeddings = tuple(self.context_embedder.apply(**{k: kwargs[k] for k in self.context_embedder.inputs }))
     44         extremities = tuple((kwargs[k] - data.train_gps_mean[v]) / data.train_gps_std[v] for k, v in self.extremities.items())
     45 
     46         inputs = tensor.concatenate(extremities + embeddings, axis=1)
     47         outputs = self.mlp.apply(inputs)
     48 
     49         return outputs
     50 
     51     @predict.property('inputs')
     52     def predict_inputs(self):
     53         return self.inputs
     54 
     55 class UniformGenerator(object):
     56     def __init__(self):
     57         self.rng = numpy.random.RandomState(123)
     58     def __call__(self, *args):
     59         return float(self.rng.uniform())
     60 
     61 class Stream(object):
     62     def __init__(self, config):
     63         self.config = config
     64 
     65     def train(self, req_vars):
     66         stream = TaxiDataset('train', data.traintest_ds)
     67 
     68         if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training:
     69             stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
     70         else:
     71             stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
     72 
     73         if not data.tvt:
     74             valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
     75             valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
     76             stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
     77 
     78         stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
     79 
     80         if hasattr(self.config, 'shuffle_batch_size'):
     81             stream = transformers.Batch(stream, iteration_scheme=ConstantScheme(self.config.shuffle_batch_size))
     82             stream = Mapping(stream, SortMapping(key=UniformGenerator()))
     83             stream = Unpack(stream)
     84 
     85         stream = transformers.taxi_add_datetime(stream)
     86         stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
     87         stream = transformers.Select(stream, tuple(req_vars))
     88         
     89         stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
     90 
     91         stream = MultiProcessing(stream)
     92 
     93         return stream
     94 
     95     def valid(self, req_vars):
     96         stream = TaxiStream(data.valid_set, data.valid_ds)
     97 
     98         stream = transformers.taxi_add_datetime(stream)
     99         stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
    100         stream = transformers.Select(stream, tuple(req_vars))
    101         return Batch(stream, iteration_scheme=ConstantScheme(1000))
    102 
    103     def test(self, req_vars):
    104         stream = TaxiStream('test', data.traintest_ds)
    105         
    106         stream = transformers.taxi_add_datetime(stream)
    107         stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
    108         stream = transformers.taxi_remove_test_only_clients(stream)
    109 
    110         return Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
    111 
    112     def inputs(self):
    113         return {'call_type': tensor.bvector('call_type'),
    114                 'origin_call': tensor.ivector('origin_call'),
    115                 'origin_stand': tensor.bvector('origin_stand'),
    116                 'taxi_id': tensor.wvector('taxi_id'),
    117                 'timestamp': tensor.ivector('timestamp'),
    118                 'day_type': tensor.bvector('day_type'),
    119                 'missing_data': tensor.bvector('missing_data'),
    120                 'latitude': tensor.matrix('latitude'),
    121                 'longitude': tensor.matrix('longitude'),
    122                 'destination_latitude': tensor.vector('destination_latitude'),
    123                 'destination_longitude': tensor.vector('destination_longitude'),
    124                 'travel_time': tensor.ivector('travel_time'),
    125                 'first_k_latitude': tensor.matrix('first_k_latitude'),
    126                 'first_k_longitude': tensor.matrix('first_k_longitude'),
    127                 'last_k_latitude': tensor.matrix('last_k_latitude'),
    128                 'last_k_longitude': tensor.matrix('last_k_longitude'),
    129                 'input_time': tensor.ivector('input_time'),
    130                 'week_of_year': tensor.bvector('week_of_year'),
    131                 'day_of_week': tensor.bvector('day_of_week'),
    132                 'qhour_of_day': tensor.bvector('qhour_of_day')}