taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

rnn_lag_tgtcls.py (1253B)


      1 import numpy
      2 import theano
      3 from theano import tensor
      4 from blocks.bricks.base import lazy
      5 from blocks.bricks import Softmax
      6 
      7 from model.rnn import RNN, Stream
      8 
      9 
     10 class Model(RNN):
     11     @lazy()
     12     def __init__(self, config, **kwargs):
     13         super(Model, self).__init__(config, rec_input_len=4, output_dim=config.tgtcls.shape[0], **kwargs)
     14         self.classes = theano.shared(numpy.array(config.tgtcls, dtype=theano.config.floatX), name='classes')
     15         self.softmax = Softmax()
     16         self.sequences.extend(['latitude_lag', 'longitude_lag'])
     17         self.children.append(self.softmax)
     18 
     19     def before_predict_all(self, kwargs):
     20         super(Model, self).before_predict_all(kwargs)
     21         kwargs['latitude_lag'] = tensor.extra_ops.repeat(kwargs['latitude'], 2, axis=0)
     22         kwargs['longitude_lag'] = tensor.extra_ops.repeat(kwargs['longitude'], 2, axis=0)
     23 
     24     def process_rto(self, rto):
     25         return tensor.dot(self.softmax.apply(rto), self.classes)
     26 
     27     def rec_input(self, latitude, longitude, latitude_lag, longitude_lag, **kwargs):
     28         return (tensor.shape_padright(latitude),
     29                 tensor.shape_padright(longitude),
     30                 tensor.shape_padright(latitude_lag),
     31                 tensor.shape_padright(longitude_lag))