taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

bidirectional_tgtcls_window.py (8606B)


      1 from model.bidirectional import SegregatedBidirectional
      2 
      3 
      4 class Model(Initializable):
      5     @lazy()
      6     def __init__(self, config, output_dim=2, **kwargs):
      7         super(Model, self).__init__(**kwargs)
      8         self.config = config
      9 
     10         self.context_embedder = ContextEmbedder(config)
     11         
     12         act = config.rec_activation() if hasattr(config, 'rec_activation') else None
     13         self.rec = SegregatedBidirectional(LSTM(dim=config.hidden_state_dim, activation=act,
     14                                                 name='recurrent'))
     15 
     16         self.fwd_fork = Fork([name for name in self.rec.prototype.apply.sequences if name!='mask'],
     17                              prototype=Linear(), name='fwd_fork')
     18         self.bkwd_fork = Fork([name for name in self.rec.prototype.apply.sequences if name!='mask'],
     19                               prototype=Linear(), name='bkwd_fork')
     20 
     21         rto_in = config.hidden_state_dim * 2 + sum(x[2] for x in config.dim_embeddings)
     22         self.rec_to_output = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()], 
     23                                  dims=[rto_in] + config.dim_hidden + [output_dim])
     24 
     25         self.softmax = Softmax()
     26 
     27         self.sequences = ['latitude', 'latitude_mask', 'longitude']
     28         self.inputs = self.sequences + self.context_embedder.inputs
     29 
     30         self.children = [ self.context_embedder, self.fwd_fork, self.bkwd_fork,
     31                           self.rec, self.rec_to_output, self.softmax ]
     32 
     33         self.classes = theano.shared(numpy.array(config.tgtcls, dtype=theano.config.floatX),
     34                                      name='classes')
     35 
     36     def _push_allocation_config(self):
     37         for i, fork in enumerate([self.fwd_fork, self.bkwd_fork]):
     38             fork.input_dim = 2 * self.config.window_size
     39             fork.output_dims = [ self.rec.children[i].get_dim(name)
     40                                  for name in fork.output_names ]
     41 
     42     def _push_initialization_config(self):
     43         for brick in [self.fwd_fork, self.bkwd_fork, self.rec, self.rec_to_output]:
     44             brick.weights_init = self.config.weights_init
     45             brick.biases_init = self.config.biases_init
     46 
     47     def process_outputs(self, outputs):
     48         return tensor.dot(self.softmax.apply(outputs), self.classes)
     49 
     50     @application(outputs=['destination'])
     51     def predict(self, latitude, longitude, latitude_mask, **kwargs):
     52         latitude = (latitude.dimshuffle(1, 0, 2) - data.train_gps_mean[0]) / data.train_gps_std[0]
     53         longitude = (longitude.dimshuffle(1, 0, 2) - data.train_gps_mean[1]) / data.train_gps_std[1]
     54         latitude_mask = latitude_mask.T
     55 
     56         rec_in = tensor.concatenate((latitude, longitude), axis=2)
     57 
     58         last_id = tensor.cast(latitude_mask.sum(axis=0) - 1, dtype='int64')
     59 
     60         path = self.rec.apply(merge(self.fwd_fork.apply(rec_in, as_dict=True),
     61                                     {'mask': latitude_mask}),
     62                               merge(self.bkwd_fork.apply(rec_in, as_dict=True),
     63                                     {'mask': latitude_mask}))[0]
     64 
     65         path_representation = (path[0][:, -self.config.hidden_state_dim:],
     66                                path[last_id - 1, tensor.arange(latitude_mask.shape[1])]
     67                                    [:, :self.config.hidden_state_dim])
     68 
     69         embeddings = tuple(self.context_embedder.apply(
     70                         **{k: kwargs[k] for k in self.context_embedder.inputs }))
     71 
     72         inputs = tensor.concatenate(path_representation + embeddings, axis=1)
     73         outputs = self.rec_to_output.apply(inputs)
     74 
     75         return self.process_outputs(outputs)
     76 
     77     @predict.property('inputs')
     78     def predict_inputs(self):
     79         return self.inputs
     80 
     81     @application(outputs=['cost'])
     82     def cost(self, **kwargs):
     83         y_hat = self.predict(**kwargs)
     84         y = tensor.concatenate((kwargs['destination_latitude'][:, None],
     85                                 kwargs['destination_longitude'][:, None]), axis=1)
     86 
     87         return error.erdist(y_hat, y).mean()
     88 
     89     @cost.property('inputs')
     90     def cost_inputs(self):
     91         return self.inputs + ['destination_latitude', 'destination_longitude']
     92 
     93 
     94 
     95 class Stream(object):
     96     def __init__(self, config):
     97         self.config = config
     98 
     99     def train(self, req_vars):
    100         stream = TaxiDataset('train', data.traintest_ds)
    101 
    102         if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training:
    103             stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
    104         else:
    105             stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
    106 
    107         if not data.tvt:
    108             valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
    109             valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
    110             stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
    111 
    112         if hasattr(self.config, 'max_splits'):
    113             stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
    114         elif not data.tvt:
    115             stream = transformers.add_destination(stream)
    116 
    117         if hasattr(self.config, 'train_max_len'):
    118             idx = stream.sources.index('latitude')
    119             def max_len_filter(x):
    120                 return len(x[idx]) <= self.config.train_max_len
    121             stream = Filter(stream, max_len_filter)
    122 
    123         stream = transformers.TaxiExcludeEmptyTrips(stream)
    124 
    125         stream = transformers.window(stream, config.window_size)
    126         
    127         stream = transformers.taxi_add_datetime(stream)
    128         stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
    129 
    130         stream = transformers.balanced_batch(stream, key='latitude',
    131                                              batch_size=self.config.batch_size,
    132                                              batch_sort_size=self.config.batch_sort_size)
    133         stream = Padding(stream, mask_sources=['latitude', 'longitude'])
    134         stream = transformers.Select(stream, req_vars)
    135         stream = MultiProcessing(stream)
    136 
    137         return stream
    138 
    139     def valid(self, req_vars):
    140         stream = TaxiStream(data.valid_set, data.valid_ds)
    141 
    142         stream = transformers.window(stream, config.window_size)
    143 
    144         stream = transformers.taxi_add_datetime(stream)
    145         stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
    146 
    147         stream = transformers.balanced_batch(stream, key='latitude',
    148                                              batch_size=self.config.batch_size,
    149                                              batch_sort_size=self.config.batch_sort_size)
    150         stream = Padding(stream, mask_sources=['latitude', 'longitude'])
    151         stream = transformers.Select(stream, req_vars)
    152         stream = MultiProcessing(stream)
    153 
    154         return stream
    155 
    156     def test(self, req_vars):
    157         stream = TaxiStream('test', data.traintest_ds)
    158 
    159         stream = transformers.window(stream, config.window_size)
    160         
    161         stream = transformers.taxi_add_datetime(stream)
    162         stream = transformers.taxi_remove_test_only_clients(stream)
    163 
    164         stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
    165 
    166         stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
    167         stream = Padding(stream, mask_sources=['latitude', 'longitude'])
    168         stream = transformers.Select(stream, req_vars)
    169         return stream
    170 
    171     def inputs(self):
    172         return {'call_type': tensor.bvector('call_type'),
    173                 'origin_call': tensor.ivector('origin_call'),
    174                 'origin_stand': tensor.bvector('origin_stand'),
    175                 'taxi_id': tensor.wvector('taxi_id'),
    176                 'timestamp': tensor.ivector('timestamp'),
    177                 'day_type': tensor.bvector('day_type'),
    178                 'missing_data': tensor.bvector('missing_data'),
    179                 'latitude': tensor.tensor('latitude'),
    180                 'longitude': tensor.tensor('longitude'),
    181                 'latitude_mask': tensor.matrix('latitude_mask'),
    182                 'longitude_mask': tensor.matrix('longitude_mask'),
    183                 'destination_latitude': tensor.vector('destination_latitude'),
    184                 'destination_longitude': tensor.vector('destination_longitude'),
    185                 'travel_time': tensor.ivector('travel_time'),
    186                 'input_time': tensor.ivector('input_time'),
    187                 'week_of_year': tensor.bvector('week_of_year'),
    188                 'day_of_week': tensor.bvector('day_of_week'),
    189                 'qhour_of_day': tensor.bvector('qhour_of_day')}
    190