stream.py (4579B)
from fuel.transformers import Batch, Padding, Mapping, SortMapping, Unpack, MultiProcessing, Filter
from fuel.streams import DataStream
from fuel.schemes import ConstantScheme, ShuffledExampleScheme

from theano import tensor

import data
from data import transformers
from data.hdf5 import TaxiDataset, TaxiStream


class StreamRec(object):
    """Build the fuel data streams (train / valid / test) used by recurrent models.

    Options are read from ``config``. The attributes used here are
    ``batch_size`` and ``batch_sort_size`` (train/valid) and ``batch_size`` (test).
    The optional attributes are ``use_cuts_for_training``, ``max_splits`` and
    ``train_max_len``.
    """

    def __init__(self, config):
        self.config = config

    @staticmethod
    def _unmasked(req_vars):
        """Return ``req_vars`` without the ``*_mask`` names.

        The mask sources are only added later by ``Padding``, so they must be
        filtered out of the first ``Select``.
        """
        return tuple(v for v in req_vars if not v.endswith('_mask'))

    def _pad_and_select(self, stream, req_vars):
        """Pad the latitude/longitude sequences, then keep only ``req_vars``."""
        stream = Padding(stream, mask_sources=['latitude', 'longitude'])
        return transformers.Select(stream, req_vars)

    def _balanced_batches(self, stream, req_vars):
        """Batch with ``balanced_batch`` keyed on latitude, pad, select, and wrap in MultiProcessing.

        Shared tail of the train and valid pipelines.
        """
        stream = transformers.balanced_batch(stream, key='latitude',
                                             batch_size=self.config.batch_size,
                                             batch_sort_size=self.config.batch_sort_size)
        stream = self._pad_and_select(stream, req_vars)
        return MultiProcessing(stream)

    def train(self, req_vars):
        """Return the training stream, yielding padded batches of ``req_vars``.

        Examples come from the 'train' split of ``data.traintest_ds``.
        """
        dataset = TaxiDataset('train', data.traintest_ds)

        if getattr(self.config, 'use_cuts_for_training', False):
            # TaxiTimeCutScheme was previously referenced without being imported,
            # which raised NameError whenever this option was enabled. Imported
            # locally so the module still loads when the option is unused.
            # NOTE(review): module path assumed from the project layout -- confirm.
            from data.cut import TaxiTimeCutScheme
            scheme = TaxiTimeCutScheme()
        else:
            scheme = ShuffledExampleScheme(dataset.num_examples)
        stream = DataStream(dataset, iteration_scheme=scheme)

        if not data.tvt:
            # Exclude the trip ids of the validation set from the training stream.
            valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
            valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
            stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)

        if hasattr(self.config, 'max_splits'):
            stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
        elif not data.tvt:
            stream = transformers.add_destination(stream)

        if hasattr(self.config, 'train_max_len'):
            idx = stream.sources.index('latitude')
            max_len = self.config.train_max_len

            def max_len_filter(x):
                # Keep examples whose latitude sequence has at most max_len points.
                return len(x[idx]) <= max_len
            stream = Filter(stream, max_len_filter)

        stream = transformers.TaxiExcludeEmptyTrips(stream)
        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.Select(stream, self._unmasked(req_vars))
        return self._balanced_batches(stream, req_vars)

    def valid(self, req_vars):
        """Return the validation stream, yielding padded batches of ``req_vars``."""
        stream = TaxiStream(data.valid_set, data.valid_ds)
        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.Select(stream, self._unmasked(req_vars))
        return self._balanced_batches(stream, req_vars)

    def test(self, req_vars):
        """Return the test stream, yielding padded batches of ``req_vars``."""
        stream = TaxiStream('test', data.traintest_ds)
        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.taxi_remove_test_only_clients(stream)
        stream = transformers.Select(stream, self._unmasked(req_vars))
        # Consecutive fixed-size batches here, unlike train/valid, which use
        # balanced_batch; no MultiProcessing wrapper either.
        stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
        return self._pad_and_select(stream, req_vars)

    def inputs(self):
        """Return the Theano symbolic input variables, keyed by source name.

        Sequence sources (latitude/longitude and their masks) are matrices;
        all other sources are vectors of the dtype given by the tensor
        constructor (b=int8, w=int16, i=int32, plain=floatX).
        """
        return {'call_type': tensor.bvector('call_type'),
                'origin_call': tensor.ivector('origin_call'),
                'origin_stand': tensor.bvector('origin_stand'),
                'taxi_id': tensor.wvector('taxi_id'),
                'timestamp': tensor.ivector('timestamp'),
                'day_type': tensor.bvector('day_type'),
                'missing_data': tensor.bvector('missing_data'),
                'latitude': tensor.matrix('latitude'),
                'longitude': tensor.matrix('longitude'),
                'latitude_mask': tensor.matrix('latitude_mask'),
                'longitude_mask': tensor.matrix('longitude_mask'),
                'destination_latitude': tensor.vector('destination_latitude'),
                'destination_longitude': tensor.vector('destination_longitude'),
                'travel_time': tensor.ivector('travel_time'),
                'input_time': tensor.ivector('input_time'),
                'week_of_year': tensor.bvector('week_of_year'),
                'day_of_week': tensor.bvector('day_of_week'),
                'qhour_of_day': tensor.bvector('qhour_of_day')}