commit b44f7113c64568a20c9f93ca17577f17d7695dcb
parent e6215fdd8b64c91210268cc8e929b19c22a53660
Author: Étienne Simon <esimon@esimon.eu>
Date: Thu, 11 Jun 2015 16:26:29 -0400
Use Mapping instead of extending Transformer
Diffstat:
2 files changed, 93 insertions(+), 52 deletions(-)
diff --git a/data/transformers.py b/data/transformers.py
@@ -3,7 +3,8 @@ import random
import numpy
import theano
-from fuel.transformers import Transformer
+from fuel.schemes import ConstantScheme
+from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack
import data
@@ -30,6 +31,29 @@ class Select(Transformer):
raise ValueError
data=next(self.child_epoch_iterator)
return [data[id] for id in self.ids]
+
+class TaxiExcludeTrips(Transformer):
+ def __init__(self, stream, exclude_list):
+ super(TaxiExcludeTrips, self).__init__(stream)
+ self.id_trip_id = stream.sources.index('trip_id')
+ self.exclude = {v: True for v in exclude_list}
+ def get_data(self, request=None):
+ if request is not None: raise ValueError
+ while True:
+ data = next(self.child_epoch_iterator)
+ if not data[self.id_trip_id] in self.exclude: break
+ return data
+
+class TaxiExcludeEmptyTrips(Transformer):
+ def __init__(self, stream):
+ super(TaxiExcludeEmptyTrips, self).__init__(stream)
+ self.latitude = stream.sources.index('latitude')
+ def get_data(self, request=None):
+ if request is not None: raise ValueError
+ while True:
+ data = next(self.child_epoch_iterator)
+ if len(data[self.latitude])>0: break
+ return data
class TaxiGenerateSplits(Transformer):
def __init__(self, data_stream, max_splits=-1):
@@ -68,18 +92,13 @@ class TaxiGenerateSplits(Transformer):
return tuple(r + [dlat, dlon, ttime])
-class TaxiAddFirstLastLen(Transformer):
- def __init__(self, k, stream):
- super(TaxiAddFirstLastLen, self).__init__(stream)
- self.sources = stream.sources + ('first_k_latitude', 'first_k_longitude',
- 'last_k_latitude', 'last_k_longitude',
- 'input_time')
- self.id_latitude = stream.sources.index('latitude')
- self.id_longitude = stream.sources.index('longitude')
+
+class _taxi_add_first_last_len_helper(object):
+ def __init__(self, k, latitude, longitude):
self.k = k
- def get_data(self, request=None):
- if request is not None: raise ValueError
- data = next(self.child_epoch_iterator)
+ self.id_latitude = id_latitude
+ self.id_longitude = id_longitude
+ def __call__(self, data):
first_k = (numpy.array(at_least_k(self.k, data[self.id_latitude], False, False)[:self.k],
dtype=theano.config.floatX),
numpy.array(at_least_k(self.k, data[self.id_longitude], False, True)[:self.k],
@@ -89,43 +108,65 @@ class TaxiAddFirstLastLen(Transformer):
numpy.array(at_least_k(self.k, data[self.id_longitude], True, True)[-self.k:],
dtype=theano.config.floatX))
input_time = (numpy.int32(15 * (len(data[self.id_latitude]) - 1)),)
- return data + first_k + last_k + input_time
+ return first_k + last_k + input_time
-class TaxiAddDateTime(Transformer):
- def __init__(self, stream):
- super(TaxiAddDateTime, self).__init__(stream)
- self.sources = stream.sources + ('week_of_year', 'day_of_week', 'qhour_of_day')
- self.id_timestamp = stream.sources.index('timestamp')
- def get_data(self, request=None):
- if request is not None: raise ValueError
- data = next(self.child_epoch_iterator)
- ts = data[self.id_timestamp]
+def taxi_add_first_last_len(stream, k):
+ fun = _taxi_add_first_last_len_helper(k, stream.sources.index('latitude'), stream.sources.index('longitude'))
+ return Mapping(stream, fun, add_sources=('first_k_latitude', 'first_k_longitude', 'last_k_latitude', 'last_k_longitude', 'input_time'))
+
+
+class _taxi_add_datetime_helper(object):
+ def __init__(self, key):
+ self.key = key
+ def __call__(self, data):
+ ts = data[self.key]
date = datetime.datetime.utcfromtimestamp(ts)
yearweek = date.isocalendar()[1] - 1
info = (numpy.int8(51 if yearweek == 52 else yearweek),
numpy.int8(date.weekday()),
numpy.int8(date.hour * 4 + date.minute / 15))
- return data + info
+ return info
+
+def taxi_add_datetime(stream):
+ fun = _taxi_add_datetime_helper(stream.sources.index('timestamp'))
+ return Mapping(stream, fun, add_sources=('week_of_year', 'day_of_week', 'qhour_of_day'))
+
+
+class _balanced_batch_helper(object):
+ def __init__(self, key):
+ self.key = key
+ def __call__(self, data):
+ return len(data[self.key])
+
+def balanced_batch(stream, key, batch_size, batch_sort_size):
+ stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size * batch_sort_size))
+ comparison = _balanced_batch_helper(stream.sources.index(key))
+ stream = Mapping(stream, SortMapping(comparison))
+ stream = Unpack(stream)
+ return Batch(stream, iteration_scheme=ConstantScheme(batch_size))
+
+
+class _taxi_remove_test_only_clients_helper(object):
+ def __init__(self, key):
+ self.key = key
+ def __call__(self, x):
+ x = list(x)
+ if x[self.key] >= data.origin_call_train_size:
+ x[self.key] = numpy.int32(0)
+ return tuple(x)
-class TaxiExcludeTrips(Transformer):
- def __init__(self, exclude_list, stream):
- super(TaxiExcludeTrips, self).__init__(stream)
- self.id_trip_id = stream.sources.index('trip_id')
- self.exclude = {v: True for v in exclude_list}
- def get_data(self, request=None):
- if request is not None: raise ValueError
- while True:
- data = next(self.child_epoch_iterator)
- if not data[self.id_trip_id] in self.exclude: break
- return data
+def taxi_remove_test_only_clients(stream):
+ fun = _taxi_remove_test_only_clients_helper(stream.sources.index('origin_call'))
+ return Mapping(stream, fun)
-class TaxiRemoveTestOnlyClients(Transformer):
- def __init__(self, stream):
- super(TaxiRemoveTestOnlyClients, self).__init__(stream)
- self.id_origin_call = stream.sources.index('origin_call')
- def get_data(self, request=None):
- if request is not None: raise ValueError
- x = list(next(self.child_epoch_iterator))
- if x[self.id_origin_call] >= data.origin_call_train_size:
- x[self.id_origin_call] = numpy.int32(0)
- return tuple(x)
+
+class _add_destination_helper(object):
+ def __init__(self, latitude, longitude):
+ self.latitude = latitude
+ self.longitude = longitude
+ def __call__(self, data):
+ return (data[self.latitude][-1], data[self.longitude][-1])
+
+def add_destination(stream):
+ fun = _add_destination_helper(stream.sources.index('latitude'), stream.sources.index('longitude'))
+ return Mapping(stream, fun, add_sources=('destination_latitude', 'destination_longitude'))
diff --git a/model/mlp.py b/model/mlp.py
@@ -61,11 +61,11 @@ class Stream(object):
else:
stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
- stream = transformers.TaxiExcludeTrips(valid_trips_ids, stream)
+ stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
- stream = transformers.TaxiAddDateTime(stream)
- stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)
+ stream = transformers.taxi_add_datetime(stream)
+ stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
stream = transformers.Select(stream, tuple(req_vars))
stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
@@ -77,17 +77,17 @@ class Stream(object):
def valid(self, req_vars):
stream = TaxiStream(self.config.valid_set, 'valid.hdf5')
- stream = transformers.TaxiAddDateTime(stream)
- stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)
+ stream = transformers.taxi_add_datetime(stream)
+ stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
stream = transformers.Select(stream, tuple(req_vars))
return Batch(stream, iteration_scheme=ConstantScheme(1000))
def test(self, req_vars):
stream = TaxiStream('test')
- stream = transformers.TaxiAddDateTime(stream)
- stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)
- stream = transformers.TaxiRemoveTestOnlyClients(stream)
+ stream = transformers.taxi_add_datetime(stream)
+ stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
+ stream = transformers.taxi_remove_test_only_clients(stream)
return Batch(stream, iteration_scheme=ConstantScheme(1))