bidirectional_tgtcls_window.py (8606B)
from model.bidirectional import SegregatedBidirectional

# NOTE(review): apart from SegregatedBidirectional, this module uses theano,
# numpy, tensor, merge, Initializable, lazy, application, LSTM, Fork, Linear,
# MLP, Rectifier, Identity, Softmax, DataStream, ShuffledExampleScheme,
# ConstantScheme, Batch, Filter, Padding, MultiProcessing, and the project
# names data, error, transformers, ContextEmbedder, TaxiDataset, TaxiStream
# and TaxiTimeCutScheme, none of which are imported here. Confirm they are
# provided by whatever loads this module, or add the corresponding imports.


class Model(Initializable):
    """Bidirectional LSTM over windows of GPS points that predicts a destination.

    The forward and backward directions of a SegregatedBidirectional LSTM each
    get their own linear projection (``fwd_fork`` / ``bkwd_fork``) of the
    normalised latitude/longitude windows. A summary of both directions is
    concatenated with the context embeddings and fed through an MLP. The MLP
    outputs are softmax-normalised and used to weight the rows of
    ``config.tgtcls`` (see ``process_outputs``).
    """

    @lazy()
    def __init__(self, config, output_dim=2, **kwargs):
        super(Model, self).__init__(**kwargs)
        self.config = config

        self.context_embedder = ContextEmbedder(config)

        act = config.rec_activation() if hasattr(config, 'rec_activation') else None
        self.rec = SegregatedBidirectional(LSTM(dim=config.hidden_state_dim, activation=act,
                                                name='recurrent'))

        # Each fork produces one output per sequence input of the recurrent
        # prototype, except 'mask', which predict() supplies separately.
        fork_outputs = [name for name in self.rec.prototype.apply.sequences
                        if name != 'mask']
        self.fwd_fork = Fork(list(fork_outputs), prototype=Linear(), name='fwd_fork')
        self.bkwd_fork = Fork(list(fork_outputs), prototype=Linear(), name='bkwd_fork')

        # MLP input width: two hidden_state_dim-wide slices of the recurrent
        # output (see predict) plus x[2] of every config.dim_embeddings entry
        # (presumably each context embedding's width -- TODO confirm).
        rto_in = config.hidden_state_dim * 2 + sum(x[2] for x in config.dim_embeddings)
        self.rec_to_output = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()],
                                 dims=[rto_in] + config.dim_hidden + [output_dim])

        self.softmax = Softmax()

        self.sequences = ['latitude', 'latitude_mask', 'longitude']
        self.inputs = self.sequences + self.context_embedder.inputs

        self.children = [self.context_embedder, self.fwd_fork, self.bkwd_fork,
                         self.rec, self.rec_to_output, self.softmax]

        self.classes = theano.shared(numpy.array(config.tgtcls, dtype=theano.config.floatX),
                                     name='classes')

    def _push_allocation_config(self):
        """Set the input/output sizes of the forward and backward forks."""
        for i, fork in enumerate([self.fwd_fork, self.bkwd_fork]):
            # predict() concatenates a latitude window and a longitude window
            # along the last axis, hence 2 * window_size input features.
            fork.input_dim = 2 * self.config.window_size
            # rec.children[0] / [1] are presumably the forward / backward
            # LSTMs -- TODO confirm against SegregatedBidirectional.
            fork.output_dims = [self.rec.children[i].get_dim(name)
                                for name in fork.output_names]

    def _push_initialization_config(self):
        """Give the forks, the recurrent brick and the MLP the configured initialisers."""
        for brick in [self.fwd_fork, self.bkwd_fork, self.rec, self.rec_to_output]:
            brick.weights_init = self.config.weights_init
            brick.biases_init = self.config.biases_init

    def process_outputs(self, outputs):
        """Softmax the MLP outputs and return the weighted sum of the rows of ``self.classes``."""
        return tensor.dot(self.softmax.apply(outputs), self.classes)

    @application(outputs=['destination'])
    def predict(self, latitude, longitude, latitude_mask, **kwargs):
        """Return the predicted destination for a batch of windowed trajectories.

        ``latitude`` and ``longitude`` must be rank 3 (they are dimshuffled
        with (1, 0, 2)); presumably (batch, time, window_size) -- TODO confirm.
        ``latitude_mask`` is transposed to match the time-major layout.
        Context inputs are taken from ``kwargs`` by the names in
        ``self.context_embedder.inputs``.
        """
        # Move the second axis first and normalise with the training GPS statistics.
        latitude = (latitude.dimshuffle(1, 0, 2) - data.train_gps_mean[0]) / data.train_gps_std[0]
        longitude = (longitude.dimshuffle(1, 0, 2) - data.train_gps_mean[1]) / data.train_gps_std[1]
        latitude_mask = latitude_mask.T

        rec_in = tensor.concatenate((latitude, longitude), axis=2)

        # Index of the last valid step of every sequence, assuming the mask is
        # a contiguous prefix of ones (right padding).
        last_id = tensor.cast(latitude_mask.sum(axis=0) - 1, dtype='int64')

        path = self.rec.apply(merge(self.fwd_fork.apply(rec_in, as_dict=True),
                                    {'mask': latitude_mask}),
                              merge(self.bkwd_fork.apply(rec_in, as_dict=True),
                                    {'mask': latitude_mask}))[0]

        hidden_dim = self.config.hidden_state_dim
        batch_ids = tensor.arange(latitude_mask.shape[1])
        # Trajectory summary: the last hidden_dim features at step 0
        # (presumably the backward state) and the first hidden_dim features
        # at each sequence's last valid step (presumably the forward state).
        # last_id is used as-is: subtracting 1 again would skip the final
        # window and wrap to index -1 for single-window trajectories.
        path_representation = (path[0][:, -hidden_dim:],
                               path[last_id, batch_ids][:, :hidden_dim])

        embeddings = tuple(self.context_embedder.apply(
            **{k: kwargs[k] for k in self.context_embedder.inputs}))

        inputs = tensor.concatenate(path_representation + embeddings, axis=1)
        outputs = self.rec_to_output.apply(inputs)

        return self.process_outputs(outputs)

    @predict.property('inputs')
    def predict_inputs(self):
        return self.inputs

    @application(outputs=['cost'])
    def cost(self, **kwargs):
        """Mean of ``error.erdist`` between the prediction and the true destination."""
        y_hat = self.predict(**kwargs)
        y = tensor.concatenate((kwargs['destination_latitude'][:, None],
                                kwargs['destination_longitude'][:, None]), axis=1)

        return error.erdist(y_hat, y).mean()

    @cost.property('inputs')
    def cost_inputs(self):
        return self.inputs + ['destination_latitude', 'destination_longitude']


class Stream(object):
    """Builds the training, validation and test data streams for ``Model``."""

    def __init__(self, config):
        self.config = config

    def _window_batch_and_pad(self, stream, req_vars):
        """Apply the windowing/batching/padding tail shared by train() and valid()."""
        stream = transformers.window(stream, self.config.window_size)

        stream = transformers.taxi_add_datetime(stream)
        # The '_mask' sources do not exist yet; Padding adds them below.
        stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))

        stream = transformers.balanced_batch(stream, key='latitude',
                                             batch_size=self.config.batch_size,
                                             batch_sort_size=self.config.batch_sort_size)
        stream = Padding(stream, mask_sources=['latitude', 'longitude'])
        stream = transformers.Select(stream, req_vars)
        stream = MultiProcessing(stream)

        return stream

    def train(self, req_vars):
        """Return the training stream restricted to ``req_vars``."""
        stream = TaxiDataset('train', data.traintest_ds)

        if getattr(self.config, 'use_cuts_for_training', False):
            stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
        else:
            stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))

        if not data.tvt:
            # Exclude the validation trips (by trip_id) from the training data.
            valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
            valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
            stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)

        if hasattr(self.config, 'max_splits'):
            stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
        elif not data.tvt:
            stream = transformers.add_destination(stream)

        if hasattr(self.config, 'train_max_len'):
            idx = stream.sources.index('latitude')

            def max_len_filter(x):
                return len(x[idx]) <= self.config.train_max_len
            stream = Filter(stream, max_len_filter)

        stream = transformers.TaxiExcludeEmptyTrips(stream)

        return self._window_batch_and_pad(stream, req_vars)

    def valid(self, req_vars):
        """Return the validation stream restricted to ``req_vars``."""
        stream = TaxiStream(data.valid_set, data.valid_ds)
        return self._window_batch_and_pad(stream, req_vars)

    def test(self, req_vars):
        """Return the test stream restricted to ``req_vars``."""
        stream = TaxiStream('test', data.traintest_ds)

        stream = transformers.window(stream, self.config.window_size)

        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.taxi_remove_test_only_clients(stream)

        # The '_mask' sources do not exist yet; Padding adds them below.
        stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))

        # Unlike train()/valid(): plain ConstantScheme batches, no
        # balanced_batch and no MultiProcessing.
        stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
        stream = Padding(stream, mask_sources=['latitude', 'longitude'])
        stream = transformers.Select(stream, req_vars)
        return stream

    def inputs(self):
        """Return the symbolic Theano variable for every stream source, keyed by name."""
        # latitude/longitude are windowed, so they are rank 3 (Model.predict
        # dimshuffles them with (1, 0, 2)); tensor.tensor('...') would pass
        # the name to TensorType as a dtype and fail.
        return {'call_type': tensor.bvector('call_type'),
                'origin_call': tensor.ivector('origin_call'),
                'origin_stand': tensor.bvector('origin_stand'),
                'taxi_id': tensor.wvector('taxi_id'),
                'timestamp': tensor.ivector('timestamp'),
                'day_type': tensor.bvector('day_type'),
                'missing_data': tensor.bvector('missing_data'),
                'latitude': tensor.tensor3('latitude'),
                'longitude': tensor.tensor3('longitude'),
                'latitude_mask': tensor.matrix('latitude_mask'),
                'longitude_mask': tensor.matrix('longitude_mask'),
                'destination_latitude': tensor.vector('destination_latitude'),
                'destination_longitude': tensor.vector('destination_longitude'),
                'travel_time': tensor.ivector('travel_time'),
                'input_time': tensor.ivector('input_time'),
                'week_of_year': tensor.bvector('week_of_year'),
                'day_of_week': tensor.bvector('day_of_week'),
                'qhour_of_day': tensor.bvector('qhour_of_day')}