mlp_emb.py (5366B)
"""MLP-over-context-embeddings destination model and its Fuel data streams."""
from theano import tensor

from fuel.transformers import Batch, MultiProcessing
from fuel.streams import DataStream
from fuel.schemes import ConstantScheme, ShuffledExampleScheme
from blocks.bricks import application, MLP, Rectifier, Initializable, Identity

import error
import data
from data import transformers
from data.hdf5 import TaxiDataset, TaxiStream
from data.cut import TaxiTimeCutScheme
from model import ContextEmbedder


class Model(Initializable):
    """Feed the context embeddings through an MLP to predict a destination.

    The MLP has one Rectifier layer per entry of ``config.dim_hidden`` and an
    Identity output layer. Its layer sizes are
    ``[config.dim_input] + config.dim_hidden + [config.dim_output]``.
    """

    def __init__(self, config, **kwargs):
        super(Model, self).__init__(**kwargs)
        self.config = config

        self.context_embedder = ContextEmbedder(config)
        self.mlp = MLP(
            activations=[Rectifier() for _ in config.dim_hidden] + [Identity()],
            dims=[config.dim_input] + config.dim_hidden + [config.dim_output])

        # Only the embedder's inputs are used. The first/last-k extremity
        # inputs (`+ self.extremities.keys()`) are currently disabled.
        self.inputs = self.context_embedder.inputs
        self.children = [self.context_embedder, self.mlp]

    def _push_initialization_config(self):
        # Hand the MLP its weight/bias initialization schemes from the config.
        self.mlp.weights_init = self.config.mlp_weights_init
        self.mlp.biases_init = self.config.mlp_biases_init

    @application(outputs=['destination'])
    def predict(self, **kwargs):
        """Return the predicted destination for a batch.

        Only the keyword arguments named in ``self.context_embedder.inputs``
        are forwarded to the embedder. Its outputs are concatenated along
        axis 1 and passed through the MLP.

        Raises:
            ValueError: if ``config.output_mode`` is neither ``"destination"``
                nor ``"clusters"``.
        """
        embedder_args = {k: kwargs[k] for k in self.context_embedder.inputs}
        embeddings = tuple(self.context_embedder.apply(**embedder_args))

        inputs = tensor.concatenate(embeddings, axis=1)
        outputs = self.mlp.apply(inputs)

        mode = self.config.output_mode
        if mode == "destination":
            # Undo the training-set normalization. Presumably the MLP predicts
            # standardized GPS coordinates -- confirm against `data`.
            return data.train_gps_std * outputs + data.train_gps_mean
        elif mode == "clusters":
            # NOTE(review): self.classes is not assigned anywhere in this file.
            # This mode needs it to be provided elsewhere -- confirm.
            return tensor.dot(outputs, self.classes)
        raise ValueError(
            "unknown output_mode {0!r}; expected 'destination' or "
            "'clusters'".format(mode))

    @predict.property('inputs')
    def predict_inputs(self):
        return self.inputs

    @application(outputs=['cost'])
    def cost(self, **kwargs):
        """Return the mean ``error.erdist`` between prediction and target.

        The target stacks ``destination_latitude`` and
        ``destination_longitude`` as the two columns of a matrix.
        """
        y_hat = self.predict(**kwargs)
        y = tensor.concatenate((kwargs['destination_latitude'][:, None],
                                kwargs['destination_longitude'][:, None]),
                               axis=1)

        return error.erdist(y_hat, y).mean()

    @cost.property('inputs')
    def cost_inputs(self):
        return self.inputs + ['destination_latitude', 'destination_longitude']


class Stream(object):
    """Build the train / valid / test Fuel streams for this model."""

    def __init__(self, config):
        self.config = config

    def train(self, req_vars):
        """Return the batched, multiprocessed training stream.

        The trip ids of ``config.valid_set`` are read from ``valid.hdf5`` and
        handed to ``TaxiExcludeTrips``. Presumably this removes validation
        trips from training. The stream is then split, datetime-augmented,
        restricted to ``req_vars`` and batched by ``config.batch_size``.
        """
        valid = TaxiDataset(self.config.valid_set, 'valid.hdf5',
                            sources=('trip_id',))
        valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]

        dataset = TaxiDataset('train')

        # Optional config flag; absent means shuffled single examples.
        if getattr(self.config, 'use_cuts_for_training', False):
            stream = DataStream(dataset, iteration_scheme=TaxiTimeCutScheme())
        else:
            stream = DataStream(
                dataset,
                iteration_scheme=ShuffledExampleScheme(dataset.num_examples))

        stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
        stream = transformers.TaxiGenerateSplits(
            stream, max_splits=self.config.max_splits)

        stream = transformers.taxi_add_datetime(stream)
        # Disabled along with the extremity inputs in Model:
        # stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
        stream = transformers.Select(stream, tuple(req_vars))

        stream = Batch(stream,
                       iteration_scheme=ConstantScheme(self.config.batch_size))

        stream = MultiProcessing(stream)

        return stream

    def valid(self, req_vars):
        """Return the validation stream restricted to ``req_vars``, in batches of 1000."""
        stream = TaxiStream(self.config.valid_set, 'valid.hdf5')

        stream = transformers.taxi_add_datetime(stream)
        # stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
        stream = transformers.Select(stream, tuple(req_vars))
        return Batch(stream, iteration_scheme=ConstantScheme(1000))

    def test(self, req_vars):
        """Return the test stream in batches of 1.

        NOTE(review): unlike train/valid, ``req_vars`` is not applied through
        ``transformers.Select`` here, so every source is kept. Confirm this is
        intentional.
        """
        stream = TaxiStream('test')

        stream = transformers.taxi_add_datetime(stream)
        # stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
        stream = transformers.taxi_remove_test_only_clients(stream)

        return Batch(stream, iteration_scheme=ConstantScheme(1))

    def inputs(self):
        """Return the symbolic Theano input variables, keyed by source name."""
        return {'call_type': tensor.bvector('call_type'),
                'origin_call': tensor.ivector('origin_call'),
                'origin_stand': tensor.bvector('origin_stand'),
                'taxi_id': tensor.wvector('taxi_id'),
                'timestamp': tensor.ivector('timestamp'),
                'day_type': tensor.bvector('day_type'),
                'missing_data': tensor.bvector('missing_data'),
                'latitude': tensor.matrix('latitude'),
                'longitude': tensor.matrix('longitude'),
                'destination_latitude': tensor.vector('destination_latitude'),
                'destination_longitude': tensor.vector('destination_longitude'),
                'travel_time': tensor.ivector('travel_time'),
                'first_k_latitude': tensor.matrix('first_k_latitude'),
                'first_k_longitude': tensor.matrix('first_k_longitude'),
                'last_k_latitude': tensor.matrix('last_k_latitude'),
                'last_k_longitude': tensor.matrix('last_k_longitude'),
                'input_time': tensor.ivector('input_time'),
                'week_of_year': tensor.bvector('week_of_year'),
                'day_of_week': tensor.bvector('day_of_week'),
                'qhour_of_day': tensor.bvector('qhour_of_day')}