taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

memory_network.py (14517B)


      1 
      2 from theano import tensor
      3 
      4 from fuel.transformers import Batch, MultiProcessing, Merge, Padding
      5 from fuel.streams import DataStream
      6 from fuel.schemes import ConstantScheme, ShuffledExampleScheme, SequentialExampleScheme
      7 from blocks.bricks import application, MLP, Rectifier, Initializable, Softmax
      8 
      9 import data
     10 from data import transformers
     11 from data.cut import TaxiTimeCutScheme
     12 from data.hdf5 import TaxiDataset, TaxiStream
     13 import error
     14 from model import ContextEmbedder
     15 
class MemoryNetworkBase(Initializable):
    """Base brick for a memory-network style destination predictor.

    A prefix encoder embeds the observed part of a trip and a candidate
    encoder embeds complete candidate trips into a common vector space.
    The softmax over dot-product similarities between prefix and
    candidates weights the candidates' destinations; the prediction is
    that weighted average.  Subclasses supply the two encoder bricks.
    """

    def __init__(self, config, prefix_encoder, candidate_encoder, **kwargs):
        super(MemoryNetworkBase, self).__init__(**kwargs)

        self.prefix_encoder = prefix_encoder
        self.candidate_encoder = candidate_encoder
        self.config = config

        self.softmax = Softmax()
        # Registering children lets Blocks initialize and configure the
        # encoders and the softmax together with this brick.
        self.children = [ self.softmax, prefix_encoder, candidate_encoder ]

        # Input names: the prefix encoder's inputs, the candidate
        # encoder's inputs under a 'candidate_' prefix, and the known
        # destination coordinates of the candidate trips.
        self.inputs = self.prefix_encoder.apply.inputs \
                      + ['candidate_%s'%x for x in self.candidate_encoder.apply.inputs] \
                      + ['candidate_destination_latitude', 'candidate_destination_longitude']

    def candidate_destination(self, **kwargs):
        """Stack candidate destination latitude/longitude vectors into an
        (n_candidates, 2) matrix."""
        return tensor.concatenate(
                (tensor.shape_padright(kwargs['candidate_destination_latitude']),
                 tensor.shape_padright(kwargs['candidate_destination_longitude'])),
                axis=1)

    @application(outputs=['cost'])
    def cost(self, **kwargs):
        """Mean error.erdist distance between predicted and true destinations."""
        y_hat = self.predict(**kwargs)
        y = tensor.concatenate((kwargs['destination_latitude'][:, None],
                                kwargs['destination_longitude'][:, None]), axis=1)

        return error.erdist(y_hat, y).mean()

    @application(outputs=['destination'])
    def predict(self, **kwargs):
        """Predict destinations as the similarity-weighted average of the
        candidates' destinations."""
        # Each encoder receives exactly the keyword arguments it declares;
        # candidate inputs are looked up under their 'candidate_' prefix.
        prefix_representation = self.prefix_encoder.apply(**{ x: kwargs[x] for x in self.prefix_encoder.apply.inputs })
        candidate_representation = self.candidate_encoder.apply(**{ x: kwargs['candidate_'+x] for x in self.candidate_encoder.apply.inputs })

        if self.config.normalize_representation:
            # L2-normalize the embeddings so the dot product below is a
            # cosine similarity.
            prefix_representation = prefix_representation \
                    / tensor.sqrt((prefix_representation ** 2).sum(axis=1, keepdims=True))
            candidate_representation = candidate_representation \
                    / tensor.sqrt((candidate_representation ** 2).sum(axis=1, keepdims=True))

        # (batch, n_candidates) similarity scores, softmaxed per example.
        similarity_score = tensor.dot(prefix_representation, candidate_representation.T)
        similarity = self.softmax.apply(similarity_score)

        return tensor.dot(similarity, self.candidate_destination(**kwargs))

    @predict.property('inputs')
    def predict_inputs(self):
        # Tells Blocks which named inputs predict() expects.
        return self.inputs

    @cost.property('inputs')
    def cost_inputs(self):
        # cost() additionally needs the ground-truth destination.
        return self.inputs + ['destination_latitude', 'destination_longitude']
     68 
     69 class StreamBase(object):
     70     def __init__(self, config):
     71         self.config = config
     72 
     73         self.prefix_inputs = [
     74                 ('call_type', tensor.bvector),
     75                 ('origin_call', tensor.ivector),
     76                 ('origin_stand', tensor.bvector),
     77                 ('taxi_id', tensor.wvector),
     78                 ('timestamp', tensor.ivector),
     79                 ('day_type', tensor.bvector),
     80                 ('missing_data', tensor.bvector),
     81                 ('latitude', tensor.matrix),
     82                 ('longitude', tensor.matrix),
     83                 ('destination_latitude', tensor.vector),
     84                 ('destination_longitude', tensor.vector),
     85                 ('travel_time', tensor.ivector),
     86                 ('input_time', tensor.ivector),
     87                 ('week_of_year', tensor.bvector),
     88                 ('day_of_week', tensor.bvector),
     89                 ('qhour_of_day', tensor.bvector)
     90             ]
     91         self.candidate_inputs = self.prefix_inputs
     92 
     93     def inputs(self):
     94         prefix_inputs = { name: constructor(name)
     95                         for name, constructor in self.prefix_inputs }
     96         candidate_inputs = { 'candidate_'+name: constructor('candidate_'+name)
     97                              for name, constructor in self.candidate_inputs }
     98         return dict(prefix_inputs.items() + candidate_inputs.items())
     99 
    100     @property
    101     def valid_dataset(self):
    102         return TaxiDataset(data.valid_set, data.valid_ds)
    103 
    104     @property
    105     def valid_trips_ids(self):
    106         valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
    107         return valid.get_data(None, slice(0, valid.num_examples))[0]
    108 
    109     @property
    110     def train_dataset(self):
    111         return TaxiDataset('train', data.traintest_ds)
    112 
    113     @property
    114     def test_dataset(self):
    115         return TaxiDataset('test', data.traintest_ds)
    116 
    117 
    118 class StreamSimple(StreamBase):
    119     def __init__(self, config):
    120         super(StreamSimple, self).__init__(config)
    121 
    122         self.prefix_inputs += [
    123                 ('first_k_latitude', tensor.matrix),
    124                 ('first_k_longitude', tensor.matrix),
    125                 ('last_k_latitude', tensor.matrix),
    126                 ('last_k_longitude', tensor.matrix),
    127         ]
    128         self.candidate_inputs = self.prefix_inputs
    129 
    130     def candidate_stream(self, n_candidates):
    131         candidate_stream = DataStream(self.train_dataset,
    132                                       iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
    133         if not data.tvt:
    134             candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids)
    135         candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream)
    136         candidate_stream = transformers.taxi_add_datetime(candidate_stream)
    137         candidate_stream = transformers.taxi_add_first_last_len(candidate_stream,
    138                                                                 self.config.n_begin_end_pts)
    139         if not data.tvt:
    140             candidate_stream = transformers.add_destination(candidate_stream)
    141 
    142         return Batch(candidate_stream,
    143                      iteration_scheme=ConstantScheme(n_candidates))
    144 
    145     def train(self, req_vars):
    146         prefix_stream = DataStream(self.train_dataset,
    147                                    iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
    148 
    149         if not data.tvt:
    150             prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids)
    151         prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
    152         prefix_stream = transformers.TaxiGenerateSplits(prefix_stream,
    153                                                         max_splits=self.config.max_splits)
    154         prefix_stream = transformers.taxi_add_datetime(prefix_stream)
    155         prefix_stream = transformers.taxi_add_first_last_len(prefix_stream,
    156                                                              self.config.n_begin_end_pts)
    157         prefix_stream = Batch(prefix_stream,
    158                               iteration_scheme=ConstantScheme(self.config.batch_size))
    159 
    160         candidate_stream = self.candidate_stream(self.config.train_candidate_size)
    161 
    162         sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
    163         stream = Merge((prefix_stream, candidate_stream), sources)
    164         stream = transformers.Select(stream, tuple(req_vars))
    165         stream = MultiProcessing(stream)
    166         return stream
    167 
    168     def valid(self, req_vars):
    169         prefix_stream = DataStream(
    170                            self.valid_dataset,
    171                            iteration_scheme=SequentialExampleScheme(self.valid_dataset.num_examples))
    172         prefix_stream = transformers.taxi_add_datetime(prefix_stream)
    173         prefix_stream = transformers.taxi_add_first_last_len(prefix_stream,
    174                                                              self.config.n_begin_end_pts)
    175         prefix_stream = Batch(prefix_stream,
    176                               iteration_scheme=ConstantScheme(self.config.batch_size))
    177 
    178         candidate_stream = self.candidate_stream(self.config.valid_candidate_size)
    179 
    180         sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
    181         stream = Merge((prefix_stream, candidate_stream), sources)
    182         stream = transformers.Select(stream, tuple(req_vars))
    183         stream = MultiProcessing(stream)
    184         return stream
    185 
    186     def test(self, req_vars):
    187         prefix_stream = DataStream(
    188                            self.test_dataset,
    189                            iteration_scheme=SequentialExampleScheme(self.test_dataset.num_examples))
    190         prefix_stream = transformers.taxi_add_datetime(prefix_stream)
    191         prefix_stream = transformers.taxi_add_first_last_len(prefix_stream,
    192                                                              self.config.n_begin_end_pts)
    193 
    194         if not data.tvt:
    195             prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream)
    196 
    197         prefix_stream = Batch(prefix_stream,
    198                               iteration_scheme=ConstantScheme(self.config.batch_size))
    199 
    200         candidate_stream = self.candidate_stream(self.config.test_candidate_size)
    201 
    202         sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
    203         stream = Merge((prefix_stream, candidate_stream), sources)
    204         stream = transformers.Select(stream, tuple(req_vars))
    205         stream = MultiProcessing(stream)
    206         return stream
    207 
    208 class StreamRecurrent(StreamBase):
    209     def __init__(self, config):
    210         super(StreamRecurrent, self).__init__(config)
    211 
    212         self.prefix_inputs += [
    213                 ('latitude_mask', tensor.matrix),
    214                 ('longitude_mask', tensor.matrix),
    215         ]
    216         self.candidate_inputs = self.prefix_inputs
    217 
    218     def candidate_stream(self, n_candidates, sortmap=True):
    219         candidate_stream = DataStream(self.train_dataset,
    220                                       iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
    221         if not data.tvt:
    222             candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids)
    223         candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream)
    224         candidate_stream = transformers.taxi_add_datetime(candidate_stream)
    225 
    226         if not data.tvt:
    227             candidate_stream = transformers.add_destination(candidate_stream)
    228 
    229         if sortmap:
    230             candidate_stream = transformers.balanced_batch(candidate_stream,
    231                                                            key='latitude',
    232                                                            batch_size=n_candidates,
    233                                                            batch_sort_size=self.config.batch_sort_size)
    234         else:
    235             candidate_stream = Batch(candidate_stream,
    236                                      iteration_scheme=ConstantScheme(n_candidates))
    237 
    238         candidate_stream = Padding(candidate_stream,
    239                                    mask_sources=['latitude', 'longitude'])
    240 
    241         return candidate_stream
    242 
    243     def train(self, req_vars):
    244         prefix_stream = DataStream(self.train_dataset,
    245                                    iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
    246 
    247         if not data.tvt:
    248             prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids)
    249         prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
    250         prefix_stream = transformers.TaxiGenerateSplits(prefix_stream,
    251                                                         max_splits=self.config.max_splits)
    252 
    253         prefix_stream = transformers.taxi_add_datetime(prefix_stream)
    254 
    255         prefix_stream = transformers.balanced_batch(prefix_stream,
    256                                                     key='latitude',
    257                                                     batch_size=self.config.batch_size,
    258                                                     batch_sort_size=self.config.batch_sort_size)
    259 
    260         prefix_stream = Padding(prefix_stream, mask_sources=['latitude', 'longitude'])
    261 
    262         candidate_stream = self.candidate_stream(self.config.train_candidate_size)
    263 
    264         sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
    265         stream = Merge((prefix_stream, candidate_stream), sources)
    266 
    267         stream = transformers.Select(stream, tuple(req_vars))
    268         # stream = MultiProcessing(stream)
    269         return stream
    270 
    271     def valid(self, req_vars):
    272         prefix_stream = DataStream(
    273                            self.valid_dataset,
    274                            iteration_scheme=SequentialExampleScheme(self.valid_dataset.num_examples))
    275 
    276         #prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
    277 
    278         prefix_stream = transformers.taxi_add_datetime(prefix_stream)
    279 
    280         prefix_stream = transformers.balanced_batch(prefix_stream,
    281                                                     key='latitude',
    282                                                     batch_size=self.config.batch_size,
    283                                                     batch_sort_size=self.config.batch_sort_size)
    284 
    285         prefix_stream = Padding(prefix_stream, mask_sources=['latitude', 'longitude'])
    286 
    287         candidate_stream = self.candidate_stream(self.config.valid_candidate_size)
    288 
    289         sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
    290         stream = Merge((prefix_stream, candidate_stream), sources)
    291 
    292         stream = transformers.Select(stream, tuple(req_vars))
    293         # stream = MultiProcessing(stream)
    294 
    295         return stream
    296 
    297     def test(self, req_vars):
    298         prefix_stream = DataStream(
    299                            self.test_dataset,
    300                            iteration_scheme=SequentialExampleScheme(self.test_dataset.num_examples))
    301 
    302         prefix_stream = transformers.taxi_add_datetime(prefix_stream)
    303         if not data.tvt:
    304             prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream)
    305 
    306         prefix_stream = Batch(prefix_stream,
    307                               iteration_scheme=ConstantScheme(self.config.batch_size))
    308         prefix_stream = Padding(prefix_stream, mask_sources=['latitude', 'longitude'])
    309 
    310         candidate_stream = self.candidate_stream(self.config.test_candidate_size, False)
    311 
    312         sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
    313         stream = Merge((prefix_stream, candidate_stream), sources)
    314 
    315         stream = transformers.Select(stream, tuple(req_vars))
    316         # stream = MultiProcessing(stream)
    317 
    318         return stream