taxi

Winning entry to the Kaggle taxi competition
git clone https://esimon.eu/repos/taxi.git
Log | Files | Refs | README

transformers.py (7778B)


      1 import datetime
      2 
      3 import numpy
      4 import theano
      5 
      6 import fuel
      7 
      8 from fuel.schemes import ConstantScheme
      9 from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack, FilterSources
     10 
     11 import data
     12 
# Pin Fuel's default seed so runs are reproducible. TaxiGenerateSplits seeds
# its split-shuffling RandomState from this value when it is constructed.
# Presumably other Fuel components that read default_seed are affected too.
fuel.config.default_seed = 123
     14 
     15 def at_least_k(k, v, pad_at_begin, is_longitude):
     16     if len(v) == 0:
     17         v = numpy.array([data.train_gps_mean[1 if is_longitude else 0]], dtype=theano.config.floatX)
     18     if len(v) < k:
     19         if pad_at_begin:
     20             v = numpy.concatenate((numpy.full((k - len(v),), v[0]), v))
     21         else:
     22             v = numpy.concatenate((v, numpy.full((k - len(v),), v[-1])))
     23     return v
     24 
# Shorter alias for fuel.transformers.FilterSources, presumably used by other
# modules as `transformers.Select` — verify against callers before removing.
Select = FilterSources
     26 
     27 class TaxiExcludeTrips(Transformer):
     28     produces_examples = True
     29 
     30     def __init__(self, stream, exclude_list):
     31         super(TaxiExcludeTrips, self).__init__(stream)
     32         self.id_trip_id = stream.sources.index('trip_id')
     33         self.exclude = {v: True for v in exclude_list}
     34     def get_data(self, request=None):
     35         if request is not None: raise ValueError
     36         while True:
     37             data = next(self.child_epoch_iterator)
     38             if not data[self.id_trip_id] in self.exclude: break
     39         return data
     40 
     41 class TaxiExcludeEmptyTrips(Transformer):
     42     produces_examples = True
     43 
     44     def __init__(self, stream):
     45         super(TaxiExcludeEmptyTrips, self).__init__(stream)
     46         self.latitude = stream.sources.index('latitude')
     47     def get_data(self, request=None):
     48         if request is not None: raise ValueError
     49         while True:
     50             data = next(self.child_epoch_iterator)
     51             if len(data[self.latitude])>0: break
     52         return data
     53         
     54 class TaxiGenerateSplits(Transformer):
     55     produces_examples = True
     56 
     57     def __init__(self, data_stream, max_splits=-1):
     58         super(TaxiGenerateSplits, self).__init__(data_stream)
     59 
     60         self.sources = data_stream.sources 
     61         if not data.tvt:
     62             self.sources += ('destination_latitude', 'destination_longitude', 'travel_time')
     63         self.max_splits = max_splits
     64         self.data = None
     65         self.splits = []
     66         self.isplit = 0
     67         self.id_latitude = data_stream.sources.index('latitude')
     68         self.id_longitude = data_stream.sources.index('longitude')
     69 
     70         self.rng = numpy.random.RandomState(fuel.config.default_seed)
     71 
     72     def get_data(self, request=None):
     73         if request is not None:
     74             raise ValueError
     75         while self.isplit >= len(self.splits):
     76             self.data = next(self.child_epoch_iterator)
     77             self.splits = range(len(self.data[self.id_longitude]))
     78             self.rng.shuffle(self.splits)
     79             if self.max_splits != -1 and len(self.splits) > self.max_splits:
     80                 self.splits = self.splits[:self.max_splits]
     81             self.isplit = 0
     82         
     83         i = self.isplit
     84         self.isplit += 1
     85         n = self.splits[i]+1
     86 
     87         r = list(self.data)
     88 
     89         r[self.id_latitude] = numpy.array(r[self.id_latitude][:n], dtype=theano.config.floatX)
     90         r[self.id_longitude] = numpy.array(r[self.id_longitude][:n], dtype=theano.config.floatX)
     91 
     92         r = tuple(r)
     93 
     94         if data.tvt:
     95             return r
     96         else:
     97             dlat = numpy.float32(self.data[self.id_latitude][-1])
     98             dlon = numpy.float32(self.data[self.id_longitude][-1])
     99             ttime = numpy.int32(15 * (len(self.data[self.id_longitude]) - 1))
    100             return r + (dlat, dlon, ttime)
    101 
    102 class _taxi_add_first_last_len_helper(object):
    103     def __init__(self, k, id_latitude, id_longitude):
    104         self.k = k
    105         self.id_latitude = id_latitude
    106         self.id_longitude = id_longitude
    107     def __call__(self, data):
    108         first_k = (numpy.array(at_least_k(self.k, data[self.id_latitude], False, False)[:self.k],
    109                                dtype=theano.config.floatX),
    110                    numpy.array(at_least_k(self.k, data[self.id_longitude], False, True)[:self.k],
    111                                dtype=theano.config.floatX))
    112         last_k = (numpy.array(at_least_k(self.k, data[self.id_latitude], True, False)[-self.k:],
    113                             dtype=theano.config.floatX),
    114                   numpy.array(at_least_k(self.k, data[self.id_longitude], True, True)[-self.k:],
    115                               dtype=theano.config.floatX))
    116         input_time = (numpy.int32(15 * (len(data[self.id_latitude]) - 1)),)
    117         return first_k + last_k + input_time
    118 
    119 def taxi_add_first_last_len(stream, k):
    120     fun = _taxi_add_first_last_len_helper(k, stream.sources.index('latitude'), stream.sources.index('longitude'))
    121     return Mapping(stream, fun, add_sources=('first_k_latitude', 'first_k_longitude', 'last_k_latitude', 'last_k_longitude', 'input_time'))
    122 
    123 
    124 class _taxi_add_datetime_helper(object):
    125     def __init__(self, key):
    126         self.key = key
    127     def __call__(self, data):
    128         ts = data[self.key]
    129         date = datetime.datetime.utcfromtimestamp(ts)
    130         yearweek = date.isocalendar()[1] - 1
    131         info = (numpy.int8(51 if yearweek == 52 else yearweek),
    132                 numpy.int8(date.weekday()),
    133                 numpy.int8(date.hour * 4 + date.minute / 15))
    134         return info
    135 
    136 def taxi_add_datetime(stream):
    137     fun = _taxi_add_datetime_helper(stream.sources.index('timestamp'))
    138     return Mapping(stream, fun, add_sources=('week_of_year', 'day_of_week', 'qhour_of_day'))
    139 
    140 
    141 class _balanced_batch_helper(object):
    142     def __init__(self, key):
    143         self.key = key
    144     def __call__(self, data):
    145         return data[self.key].shape[0]
    146 
    147 def balanced_batch(stream, key, batch_size, batch_sort_size):
    148     stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size * batch_sort_size))
    149     comparison = _balanced_batch_helper(stream.sources.index(key))
    150     stream = Mapping(stream, SortMapping(comparison))
    151     stream = Unpack(stream)
    152     return Batch(stream, iteration_scheme=ConstantScheme(batch_size))
    153 
    154 
    155 class _taxi_remove_test_only_clients_helper(object):
    156     def __init__(self, key):
    157         self.key = key
    158     def __call__(self, x):
    159         x = list(x)
    160         if x[self.key] >= data.origin_call_train_size:
    161             x[self.key] = numpy.int32(0)
    162         return tuple(x)
    163 
    164 def taxi_remove_test_only_clients(stream):
    165     fun = _taxi_remove_test_only_clients_helper(stream.sources.index('origin_call'))
    166     return Mapping(stream, fun)
    167 
    168 
    169 class _add_destination_helper(object):
    170     def __init__(self, latitude, longitude):
    171         self.latitude = latitude
    172         self.longitude = longitude
    173     def __call__(self, data):
    174         return (data[self.latitude][-1], data[self.longitude][-1])
    175 
    176 def add_destination(stream):
    177     fun = _add_destination_helper(stream.sources.index('latitude'), stream.sources.index('longitude'))
    178     return Mapping(stream, fun, add_sources=('destination_latitude', 'destination_longitude'))
    179 
    180 class _window_helper(object):
    181     def __init__(self, latitude, longitude, window_len):
    182         self.latitude = latitude
    183         self.longitude = longitude
    184         self.window_len = window_len
    185     def makewindow(self, x):
    186         assert len(x.shape) == 1
    187 
    188         if x.shape[0] < self.window_len:
    189             x = numpy.concatenate(
    190                 [numpy.full((self.window_len - x.shape[0],), x[0]), x])
    191             
    192         y = [x[i: i+x.shape[0]-self.window_len+1][:, None]
    193              for i in range(self.window_len)]
    194 
    195         return numpy.concatenate(y, axis=1)
    196 
    197     def __call__(self, data):
    198         data = list(data)
    199         data[self.latitude] = self.makewindow(data[self.latitude])
    200         data[self.longitude] = self.makewindow(data[self.longitude])
    201         return tuple(data)
    202 
    203 
    204 def window(stream, window_len):
    205     fun = _window_helper(stream.sources.index('latitude'),
    206                          stream.sources.index('longitude'),
    207                          window_len)
    208     return Mapping(stream, fun)
    209