# transformers.py (7778B)
"""Fuel stream transformers for taxi-trajectory data.

Every helper here wraps or maps over a Fuel data stream whose examples are
tuples aligned with ``stream.sources``; positions are resolved once, by
source name, when a transformer or mapping helper is constructed.
"""
import datetime

import numpy
import theano

import fuel

from fuel.schemes import ConstantScheme
from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack, FilterSources

import data

fuel.config.default_seed = 123


def at_least_k(k, v, pad_at_begin, is_longitude):
    """Return `v` extended to at least `k` elements by repeating an edge value.

    Parameters
    ----------
    k : int
        Minimum length of the returned sequence.
    v : sequence or numpy.ndarray
        One-dimensional sequence of coordinates.
    pad_at_begin : bool
        If true, pad on the left by repeating ``v[0]``; otherwise pad on the
        right by repeating ``v[-1]``.
    is_longitude : bool
        Selects which entry of ``data.train_gps_mean`` (index 1 if true,
        index 0 otherwise) replaces an empty `v`.

    Returns
    -------
    `v` unchanged when it already has `k` or more elements, otherwise a
    padded ``numpy.ndarray`` of length `k`.
    """
    # len() rather than truthiness: `v` may be a numpy array.
    if len(v) == 0:
        v = numpy.array([data.train_gps_mean[1 if is_longitude else 0]], dtype=theano.config.floatX)
    if len(v) < k:
        if pad_at_begin:
            v = numpy.concatenate((numpy.full((k - len(v),), v[0]), v))
        else:
            v = numpy.concatenate((v, numpy.full((k - len(v),), v[-1])))
    return v


# Alias kept so callers can refer to FilterSources as `Select`.
Select = FilterSources


class TaxiExcludeTrips(Transformer):
    """Drop every example whose ``trip_id`` is in `exclude_list`."""

    produces_examples = True

    def __init__(self, stream, exclude_list):
        super(TaxiExcludeTrips, self).__init__(stream)
        self.id_trip_id = stream.sources.index('trip_id')
        # Only membership is ever tested, so a set is sufficient.
        self.exclude = set(exclude_list)

    def get_data(self, request=None):
        """Return the next example whose trip id is not excluded.

        Raises ValueError if `request` is given. The StopIteration raised by
        the child iterator when it is exhausted is left to propagate.
        """
        if request is not None:
            raise ValueError("TaxiExcludeTrips does not support requests")
        while True:
            example = next(self.child_epoch_iterator)
            if example[self.id_trip_id] not in self.exclude:
                return example


class TaxiExcludeEmptyTrips(Transformer):
    """Drop every example whose ``latitude`` sequence is empty."""

    produces_examples = True

    def __init__(self, stream):
        super(TaxiExcludeEmptyTrips, self).__init__(stream)
        self.latitude = stream.sources.index('latitude')

    def get_data(self, request=None):
        """Return the next example that has at least one latitude point.

        Raises ValueError if `request` is given. The StopIteration raised by
        the child iterator when it is exhausted is left to propagate.
        """
        if request is not None:
            raise ValueError("TaxiExcludeEmptyTrips does not support requests")
        while True:
            example = next(self.child_epoch_iterator)
            if len(example[self.latitude]) > 0:
                return example


class TaxiGenerateSplits(Transformer):
    """Emit randomly chosen prefixes of each trip as separate examples.

    For every trip read from the child stream, the prefix lengths
    ``1 .. len(trip)`` are shuffled and at most `max_splits` of them are
    used (all of them when `max_splits` is -1). Each emitted example is the
    trip with ``latitude`` and ``longitude`` truncated to one such prefix.

    When ``data.tvt`` is false, three sources are appended:
    ``destination_latitude``, ``destination_longitude`` (the last point of
    the full trip) and ``travel_time``.
    NOTE(review): ``data.tvt`` presumably marks data that already carries
    its own targets -- confirm against the `data` module.
    """

    produces_examples = True

    def __init__(self, data_stream, max_splits=-1):
        super(TaxiGenerateSplits, self).__init__(data_stream)

        self.sources = data_stream.sources
        if not data.tvt:
            self.sources += ('destination_latitude', 'destination_longitude', 'travel_time')
        self.max_splits = max_splits
        self.data = None    # trip currently being split
        self.splits = []    # shuffled (prefix length - 1) values for that trip
        self.isplit = 0     # index of the next entry of `splits` to emit
        self.id_latitude = data_stream.sources.index('latitude')
        self.id_longitude = data_stream.sources.index('longitude')

        self.rng = numpy.random.RandomState(fuel.config.default_seed)

    def get_data(self, request=None):
        """Return the next prefix example, pulling a new trip when needed.

        Raises ValueError if `request` is given. The StopIteration raised by
        the child iterator when it is exhausted is left to propagate.
        """
        if request is not None:
            raise ValueError("TaxiGenerateSplits does not support requests")
        # A trip with no points yields an empty split list, so the loop moves
        # straight on to the following trip.
        while self.isplit >= len(self.splits):
            self.data = next(self.child_epoch_iterator)
            # Materialise as a list: shuffle() works in place and a Python 3
            # `range` object does not support item assignment.
            self.splits = list(range(len(self.data[self.id_longitude])))
            self.rng.shuffle(self.splits)
            if self.max_splits != -1 and len(self.splits) > self.max_splits:
                self.splits = self.splits[:self.max_splits]
            self.isplit = 0

        prefix_len = self.splits[self.isplit] + 1
        self.isplit += 1

        r = list(self.data)
        r[self.id_latitude] = numpy.array(r[self.id_latitude][:prefix_len],
                                          dtype=theano.config.floatX)
        r[self.id_longitude] = numpy.array(r[self.id_longitude][:prefix_len],
                                           dtype=theano.config.floatX)
        r = tuple(r)

        if data.tvt:
            return r

        dlat = numpy.float32(self.data[self.id_latitude][-1])
        dlon = numpy.float32(self.data[self.id_longitude][-1])
        # 15 per step between consecutive points of the full trip.
        # NOTE(review): presumably a 15-second sampling interval -- confirm.
        ttime = numpy.int32(15 * (len(self.data[self.id_longitude]) - 1))
        return r + (dlat, dlon, ttime)


class _taxi_add_first_last_len_helper(object):
    """Mapping callable computing the first/last `k` points and input time."""

    def __init__(self, k, id_latitude, id_longitude):
        self.k = k
        self.id_latitude = id_latitude
        self.id_longitude = id_longitude

    def __call__(self, data):
        """Return ``(first_k_lat, first_k_lon, last_k_lat, last_k_lon, input_time)``.

        Trips shorter than `k` are padded by `at_least_k`: at the end for the
        first-k arrays and at the beginning for the last-k arrays.
        """
        lat = data[self.id_latitude]
        lon = data[self.id_longitude]
        float_x = theano.config.floatX
        first_k = (numpy.array(at_least_k(self.k, lat, False, False)[:self.k], dtype=float_x),
                   numpy.array(at_least_k(self.k, lon, False, True)[:self.k], dtype=float_x))
        last_k = (numpy.array(at_least_k(self.k, lat, True, False)[-self.k:], dtype=float_x),
                  numpy.array(at_least_k(self.k, lon, True, True)[-self.k:], dtype=float_x))
        # Same 15-per-step convention as the travel time of TaxiGenerateSplits.
        input_time = (numpy.int32(15 * (len(lat) - 1)),)
        return first_k + last_k + input_time


def taxi_add_first_last_len(stream, k):
    """Add first/last-`k` coordinate sources and ``input_time`` to `stream`."""
    fun = _taxi_add_first_last_len_helper(k,
                                          stream.sources.index('latitude'),
                                          stream.sources.index('longitude'))
    return Mapping(stream, fun,
                   add_sources=('first_k_latitude', 'first_k_longitude',
                                'last_k_latitude', 'last_k_longitude',
                                'input_time'))


class _taxi_add_datetime_helper(object):
    """Mapping callable deriving calendar features from a timestamp source."""

    def __init__(self, key):
        self.key = key

    def __call__(self, data):
        """Return ``(week_of_year, day_of_week, qhour_of_day)`` as int8 values.

        The timestamp is interpreted as seconds since the epoch, in UTC.
        """
        ts = data[self.key]
        date = datetime.datetime.utcfromtimestamp(ts)
        # Zero-based ISO week number; index 52 (ISO week 53) is folded into 51.
        yearweek = date.isocalendar()[1] - 1
        info = (numpy.int8(51 if yearweek == 52 else yearweek),
                numpy.int8(date.weekday()),
                # Floor division keeps the quarter-hour index an integer on
                # both Python 2 and Python 3.
                numpy.int8(date.hour * 4 + date.minute // 15))
        return info


def taxi_add_datetime(stream):
    """Add ``week_of_year``, ``day_of_week`` and ``qhour_of_day`` sources."""
    fun = _taxi_add_datetime_helper(stream.sources.index('timestamp'))
    return Mapping(stream, fun, add_sources=('week_of_year', 'day_of_week', 'qhour_of_day'))


class _balanced_batch_helper(object):
    """Sort key: length of the first axis of the source at index `key`."""

    def __init__(self, key):
        self.key = key

    def __call__(self, data):
        return data[self.key].shape[0]


def balanced_batch(stream, key, batch_size, batch_sort_size):
    """Batch `stream` so that each batch holds examples of similar length.

    Examples are first grouped into chunks of ``batch_size * batch_sort_size``,
    each chunk is sorted by the first-axis length of source `key`, unpacked
    back into single examples and finally re-batched into batches of
    `batch_size`.
    """
    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size * batch_sort_size))
    comparison = _balanced_batch_helper(stream.sources.index(key))
    stream = Mapping(stream, SortMapping(comparison))
    stream = Unpack(stream)
    return Batch(stream, iteration_scheme=ConstantScheme(batch_size))


class _taxi_remove_test_only_clients_helper(object):
    """Mapping callable replacing out-of-range ``origin_call`` ids with 0."""

    def __init__(self, key):
        self.key = key

    def __call__(self, x):
        x = list(x)
        # Ids at or above data.origin_call_train_size are replaced by 0.
        if x[self.key] >= data.origin_call_train_size:
            x[self.key] = numpy.int32(0)
        return tuple(x)


def taxi_remove_test_only_clients(stream):
    """Map ``origin_call`` ids >= ``data.origin_call_train_size`` to 0."""
    fun = _taxi_remove_test_only_clients_helper(stream.sources.index('origin_call'))
    return Mapping(stream, fun)


class _add_destination_helper(object):
    """Mapping callable returning the last latitude/longitude of a trip."""

    def __init__(self, latitude, longitude):
        self.latitude = latitude
        self.longitude = longitude

    def __call__(self, data):
        return (data[self.latitude][-1], data[self.longitude][-1])


def add_destination(stream):
    """Add ``destination_latitude`` / ``destination_longitude`` sources.

    The destination is the last point of each trip, so the trip must contain
    at least one point.
    """
    fun = _add_destination_helper(stream.sources.index('latitude'),
                                  stream.sources.index('longitude'))
    return Mapping(stream, fun, add_sources=('destination_latitude', 'destination_longitude'))


class _window_helper(object):
    """Mapping callable turning coordinate sequences into sliding windows."""

    def __init__(self, latitude, longitude, window_len):
        self.latitude = latitude
        self.longitude = longitude
        self.window_len = window_len

    def makewindow(self, x):
        """Return all length-`window_len` sliding windows of 1-D array `x`.

        The result has shape ``(len(x) - window_len + 1, window_len)``. If `x`
        is shorter than the window it is first left-padded by repeating
        ``x[0]``, which gives a single row. `x` must not be empty.
        """
        assert len(x.shape) == 1

        if x.shape[0] < self.window_len:
            x = numpy.concatenate(
                [numpy.full((self.window_len - x.shape[0],), x[0]), x])

        n_windows = x.shape[0] - self.window_len + 1
        # Column i holds element i of every window.
        columns = [x[i: i + n_windows][:, None]
                   for i in range(self.window_len)]

        return numpy.concatenate(columns, axis=1)

    def __call__(self, data):
        data = list(data)
        data[self.latitude] = self.makewindow(data[self.latitude])
        data[self.longitude] = self.makewindow(data[self.longitude])
        return tuple(data)


def window(stream, window_len):
    """Replace ``latitude``/``longitude`` with their sliding-window matrices."""
    fun = _window_helper(stream.sources.index('latitude'),
                         stream.sources.index('longitude'),
                         window_len)
    return Mapping(stream, fun)