mlp.py (5804B)
from theano import tensor
import numpy

import fuel
import blocks

from fuel.transformers import Batch, MultiProcessing, Mapping, SortMapping, Unpack
from fuel.streams import DataStream
from fuel.schemes import ConstantScheme, ShuffledExampleScheme
from blocks.bricks import application, MLP, Rectifier, Initializable

import data
from data import transformers
from data.hdf5 import TaxiDataset, TaxiStream
from data.cut import TaxiTimeCutScheme
from model import ContextEmbedder

# Fix RNG seeds for both frameworks so runs are reproducible.
blocks.config.default_seed = 123
fuel.config.default_seed = 123


class FFMLP(Initializable):
    """Feed-forward MLP over embedded context features plus the first/last
    k GPS points of a trip.

    The model concatenates (a) normalized latitude/longitude matrices for the
    first and last k points and (b) the embeddings produced by
    ``ContextEmbedder``, then pushes the result through a rectifier MLP.

    Parameters
    ----------
    config : module/object
        Must provide ``dim_hidden`` (iterable of hidden-layer sizes),
        ``dim_input``, ``dim_output``, ``mlp_weights_init`` and
        ``mlp_biases_init``.
    output_layer : callable or None
        Optional brick class for a final activation; when None the MLP ends
        at the last hidden Rectifier and no output dimension is appended.
    """

    def __init__(self, config, output_layer=None, **kwargs):
        super(FFMLP, self).__init__(**kwargs)
        self.config = config

        self.context_embedder = ContextEmbedder(config)

        # With no explicit output layer, both lists stay empty so the MLP
        # below gets only the hidden stack.
        output_activation = [] if output_layer is None else [output_layer()]
        output_dim = [] if output_layer is None else [config.dim_output]
        self.mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden] + output_activation,
                       dims=[config.dim_input] + config.dim_hidden + output_dim)

        # Maps input-source name -> GPS axis index (0 = latitude, 1 = longitude),
        # e.g. {'first_k_latitude': 0, 'last_k_longitude': 1, ...}.
        self.extremities = {'%s_k_%s' % (side, ['latitude', 'longitude'][axis]): axis for side in ['first', 'last'] for axis in [0, 1]}
        # NOTE(review): list + dict.keys() works in Python 2 only; under
        # Python 3 this raises TypeError (dict_keys is not a list) — confirm
        # the intended interpreter before porting.
        self.inputs = self.context_embedder.inputs + self.extremities.keys()
        self.children = [ self.context_embedder, self.mlp ]

    def _push_initialization_config(self):
        # Blocks hook: propagate weight/bias initializers to the child MLP
        # before its parameters are allocated.
        self.mlp.weights_init = self.config.mlp_weights_init
        self.mlp.biases_init = self.config.mlp_biases_init

    @application(outputs=['prediction'])
    def predict(self, **kwargs):
        """Build the prediction graph from keyword inputs named by ``self.inputs``."""
        embeddings = tuple(self.context_embedder.apply(**{k: kwargs[k] for k in self.context_embedder.inputs }))
        # Standardize each extremity coordinate with the training-set GPS
        # mean/std for its axis (v is 0 for latitude, 1 for longitude).
        extremities = tuple((kwargs[k] - data.train_gps_mean[v]) / data.train_gps_std[v] for k, v in self.extremities.items())

        # Concatenate along the feature axis; assumes every input is
        # (batch, features) — TODO confirm embedder output shapes.
        inputs = tensor.concatenate(extremities + embeddings, axis=1)
        outputs = self.mlp.apply(inputs)

        return outputs

    @predict.property('inputs')
    def predict_inputs(self):
        # Tells Blocks which named inputs the `predict` application expects.
        return self.inputs


class UniformGenerator(object):
    """Deterministic uniform-random key function (fixed seed 123).

    Used as a SortMapping key below: sorting by an ignored-argument random
    key shuffles examples reproducibly.
    """
    def __init__(self):
        self.rng = numpy.random.RandomState(123)
    def __call__(self, *args):
        # Arguments (the example) are ignored; only the RNG state matters.
        return float(self.rng.uniform())


class Stream(object):
    """Factory for the train/valid/test Fuel data streams for this model."""

    def __init__(self, config):
        self.config = config

    def train(self, req_vars):
        """Training stream: cut/shuffle, drop validation trips, generate
        prefix splits, add datetime + first/last-k features, then batch."""
        stream = TaxiDataset('train', data.traintest_ds)

        if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training:
            stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
        else:
            stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))

        if not data.tvt:
            # Exclude trips that belong to the held-out validation set.
            valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
            valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
            stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)

        stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)

        if hasattr(self.config, 'shuffle_batch_size'):
            # Approximate shuffling: batch, re-sort each batch by a seeded
            # random key, then unpack back to single examples.
            # NOTE(review): this uses transformers.Batch while the final
            # batching below uses fuel's Batch — presumably intentional,
            # but verify the two are interchangeable here.
            stream = transformers.Batch(stream, iteration_scheme=ConstantScheme(self.config.shuffle_batch_size))
            stream = Mapping(stream, SortMapping(key=UniformGenerator()))
            stream = Unpack(stream)

        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
        stream = transformers.Select(stream, tuple(req_vars))

        stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))

        # Run the pipeline in a background process to overlap data
        # preparation with training.
        stream = MultiProcessing(stream)

        return stream

    def valid(self, req_vars):
        """Validation stream: same feature transformers, fixed batch size 1000."""
        stream = TaxiStream(data.valid_set, data.valid_ds)

        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
        stream = transformers.Select(stream, tuple(req_vars))
        return Batch(stream, iteration_scheme=ConstantScheme(1000))

    def test(self, req_vars):
        """Test stream: adds the same features; note req_vars is accepted but
        not used for a Select here (unlike train/valid) — TODO confirm intended."""
        stream = TaxiStream('test', data.traintest_ds)

        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
        stream = transformers.taxi_remove_test_only_clients(stream)

        return Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))

    def inputs(self):
        """Theano input variables, keyed by data-source name.

        Dtype prefixes follow Theano conventions: b = int8, w = int16,
        i = int32; matrices are (batch, k) coordinate blocks.
        """
        return {'call_type': tensor.bvector('call_type'),
                'origin_call': tensor.ivector('origin_call'),
                'origin_stand': tensor.bvector('origin_stand'),
                'taxi_id': tensor.wvector('taxi_id'),
                'timestamp': tensor.ivector('timestamp'),
                'day_type': tensor.bvector('day_type'),
                'missing_data': tensor.bvector('missing_data'),
                'latitude': tensor.matrix('latitude'),
                'longitude': tensor.matrix('longitude'),
                'destination_latitude': tensor.vector('destination_latitude'),
                'destination_longitude': tensor.vector('destination_longitude'),
                'travel_time': tensor.ivector('travel_time'),
                'first_k_latitude': tensor.matrix('first_k_latitude'),
                'first_k_longitude': tensor.matrix('first_k_longitude'),
                'last_k_latitude': tensor.matrix('last_k_latitude'),
                'last_k_longitude': tensor.matrix('last_k_longitude'),
                'input_time': tensor.ivector('input_time'),
                'week_of_year': tensor.bvector('week_of_year'),
                'day_of_week': tensor.bvector('day_of_week'),
                'qhour_of_day': tensor.bvector('qhour_of_day')}