memory_network.py (14517B)
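"""Memory-network-style destination prediction for taxi trips.

`MemoryNetworkBase` encodes a trip prefix and a set of candidate trips with
two encoder bricks, scores each (prefix, candidate) pair by dot product, and
predicts the destination as the softmax-weighted average of the candidate
destinations.  The `Stream*` classes build the matching Fuel pipelines:
`StreamSimple` feeds fixed-size prefix summaries (first/last k points), while
`StreamRecurrent` feeds padded variable-length GPS sequences with masks.
"""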
from theano import tensor

from fuel.transformers import Batch, MultiProcessing, Merge, Padding
from fuel.streams import DataStream
from fuel.schemes import ConstantScheme, ShuffledExampleScheme, SequentialExampleScheme
from blocks.bricks import application, MLP, Rectifier, Initializable, Softmax

import data
from data import transformers
from data.cut import TaxiTimeCutScheme
from data.hdf5 import TaxiDataset, TaxiStream
import error
from model import ContextEmbedder


class MemoryNetworkBase(Initializable):
    def __init__(self, config, prefix_encoder, candidate_encoder, **kwargs):
        super(MemoryNetworkBase, self).__init__(**kwargs)

        self.prefix_encoder = prefix_encoder
        self.candidate_encoder = candidate_encoder
        self.config = config

        self.softmax = Softmax()
        self.children = [self.softmax, prefix_encoder, candidate_encoder]

        self.inputs = self.prefix_encoder.apply.inputs \
            + ['candidate_%s' % x for x in self.candidate_encoder.apply.inputs] \
            + ['candidate_destination_latitude', 'candidate_destination_longitude']

    def candidate_destination(self, **kwargs):
        # Stack candidate latitudes and longitudes into an (n_candidates, 2) matrix.
        return tensor.concatenate(
            (tensor.shape_padright(kwargs['candidate_destination_latitude']),
             tensor.shape_padright(kwargs['candidate_destination_longitude'])),
            axis=1)

    @application(outputs=['cost'])
    def cost(self, **kwargs):
        y_hat = self.predict(**kwargs)
        y = tensor.concatenate((kwargs['destination_latitude'][:, None],
                                kwargs['destination_longitude'][:, None]), axis=1)

        # Mean distance (error.erdist) between predicted and true destinations.
        return error.erdist(y_hat, y).mean()

    @application(outputs=['destination'])
    def predict(self, **kwargs):
        # Encode the trip prefix and each candidate trip separately.
        prefix_representation = self.prefix_encoder.apply(
            **{x: kwargs[x] for x in self.prefix_encoder.apply.inputs})
        candidate_representation = self.candidate_encoder.apply(
            **{x: kwargs['candidate_' + x] for x in self.candidate_encoder.apply.inputs})

        if self.config.normalize_representation:
            # L2-normalise, turning the dot product below into a cosine similarity.
            prefix_representation = prefix_representation \
                / tensor.sqrt((prefix_representation ** 2).sum(axis=1, keepdims=True))
            candidate_representation = candidate_representation \
                / tensor.sqrt((candidate_representation ** 2).sum(axis=1, keepdims=True))

        # Attention over candidates: softmax-normalised similarity weights.
        similarity_score = tensor.dot(prefix_representation, candidate_representation.T)
        similarity = self.softmax.apply(similarity_score)

        # Predicted destination = weighted average of candidate destinations.
        return tensor.dot(similarity, self.candidate_destination(**kwargs))

    @predict.property('inputs')
    def predict_inputs(self):
        return self.inputs

    @cost.property('inputs')
    def cost_inputs(self):
        return self.inputs + ['destination_latitude', 'destination_longitude']
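# Note on `predict`: with B prefixes and C candidates, `similarity_score` has
# shape (B, C); the row-wise softmax turns each row into attention weights
# summing to 1, so the output is a convex combination of the C candidate
# destinations, shape (B, 2).  A minimal numpy sketch of that final mixture
# (illustrative only -- the weights and coordinates below are made up):
#
#   >>> import numpy as np
#   >>> weights = np.array([[0.7, 0.2, 0.1]])     # (B=1, C=3) softmax output
#   >>> dests = np.array([[41.15, -8.61],
#   ...                   [41.16, -8.63],
#   ...                   [41.14, -8.60]])        # (C=3, 2) candidate lat/lon
#   >>> weights.dot(dests)                        # -> approx [[41.151, -8.613]]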
class StreamBase(object):
    """Declares the Theano input variables and datasets shared by all streams."""

    def __init__(self, config):
        self.config = config

        self.prefix_inputs = [
            ('call_type', tensor.bvector),
            ('origin_call', tensor.ivector),
            ('origin_stand', tensor.bvector),
            ('taxi_id', tensor.wvector),
            ('timestamp', tensor.ivector),
            ('day_type', tensor.bvector),
            ('missing_data', tensor.bvector),
            ('latitude', tensor.matrix),
            ('longitude', tensor.matrix),
            ('destination_latitude', tensor.vector),
            ('destination_longitude', tensor.vector),
            ('travel_time', tensor.ivector),
            ('input_time', tensor.ivector),
            ('week_of_year', tensor.bvector),
            ('day_of_week', tensor.bvector),
            ('qhour_of_day', tensor.bvector)
        ]
        self.candidate_inputs = self.prefix_inputs

    def inputs(self):
        prefix_inputs = {name: constructor(name)
                         for name, constructor in self.prefix_inputs}
        candidate_inputs = {'candidate_' + name: constructor('candidate_' + name)
                            for name, constructor in self.candidate_inputs}
        # dict(a.items() + b.items()) only works on Python 2; merge portably.
        merged = dict(prefix_inputs)
        merged.update(candidate_inputs)
        return merged

    @property
    def valid_dataset(self):
        return TaxiDataset(data.valid_set, data.valid_ds)

    @property
    def valid_trips_ids(self):
        valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
        return valid.get_data(None, slice(0, valid.num_examples))[0]

    @property
    def train_dataset(self):
        return TaxiDataset('train', data.traintest_ds)

    @property
    def test_dataset(self):
        return TaxiDataset('test', data.traintest_ds)


class StreamSimple(StreamBase):
    def __init__(self, config):
        super(StreamSimple, self).__init__(config)

        self.prefix_inputs += [
            ('first_k_latitude', tensor.matrix),
            ('first_k_longitude', tensor.matrix),
            ('last_k_latitude', tensor.matrix),
            ('last_k_longitude', tensor.matrix),
        ]
        self.candidate_inputs = self.prefix_inputs

    def candidate_stream(self, n_candidates):
        # Candidates are drawn (shuffled) from the training set and re-batched
        # n_candidates at a time.
        candidate_stream = DataStream(self.train_dataset,
                                      iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
        if not data.tvt:
            candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids)
        candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream)
        candidate_stream = transformers.taxi_add_datetime(candidate_stream)
        candidate_stream = transformers.taxi_add_first_last_len(candidate_stream,
                                                                self.config.n_begin_end_pts)
        if not data.tvt:
            candidate_stream = transformers.add_destination(candidate_stream)

        return Batch(candidate_stream,
                     iteration_scheme=ConstantScheme(n_candidates))

    def train(self, req_vars):
        prefix_stream = DataStream(self.train_dataset,
                                   iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))

        if not data.tvt:
            prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids)
        prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
        prefix_stream = transformers.TaxiGenerateSplits(prefix_stream,
                                                        max_splits=self.config.max_splits)
        prefix_stream = transformers.taxi_add_datetime(prefix_stream)
        prefix_stream = transformers.taxi_add_first_last_len(prefix_stream,
                                                             self.config.n_begin_end_pts)
        prefix_stream = Batch(prefix_stream,
                              iteration_scheme=ConstantScheme(self.config.batch_size))

        candidate_stream = self.candidate_stream(self.config.train_candidate_size)

        sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)
        stream = transformers.Select(stream, tuple(req_vars))
        stream = MultiProcessing(stream)
        return stream

    def valid(self, req_vars):
        prefix_stream = DataStream(
            self.valid_dataset,
            iteration_scheme=SequentialExampleScheme(self.valid_dataset.num_examples))
        prefix_stream = transformers.taxi_add_datetime(prefix_stream)
        prefix_stream = transformers.taxi_add_first_last_len(prefix_stream,
                                                             self.config.n_begin_end_pts)
        prefix_stream = Batch(prefix_stream,
                              iteration_scheme=ConstantScheme(self.config.batch_size))

        candidate_stream = self.candidate_stream(self.config.valid_candidate_size)

        sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)
        stream = transformers.Select(stream, tuple(req_vars))
        stream = MultiProcessing(stream)
        return stream

    def test(self, req_vars):
        prefix_stream = DataStream(
            self.test_dataset,
            iteration_scheme=SequentialExampleScheme(self.test_dataset.num_examples))
        prefix_stream = transformers.taxi_add_datetime(prefix_stream)
        prefix_stream = transformers.taxi_add_first_last_len(prefix_stream,
                                                             self.config.n_begin_end_pts)

        if not data.tvt:
            prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream)

        prefix_stream = Batch(prefix_stream,
                              iteration_scheme=ConstantScheme(self.config.batch_size))

        candidate_stream = self.candidate_stream(self.config.test_candidate_size)

        sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)
        stream = transformers.Select(stream, tuple(req_vars))
        stream = MultiProcessing(stream)
        return stream
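# Both Stream* classes share the same pattern: a prefix stream (the trips to
# predict) is merged with an independently shuffled candidate stream drawn
# from the training set, whose sources are renamed with a 'candidate_' prefix
# so they line up with MemoryNetworkBase.inputs.  Each prefix batch is thus
# paired with a fresh batch of n_candidates memory items.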
class StreamRecurrent(StreamBase):
    def __init__(self, config):
        super(StreamRecurrent, self).__init__(config)

        self.prefix_inputs += [
            ('latitude_mask', tensor.matrix),
            ('longitude_mask', tensor.matrix),
        ]
        self.candidate_inputs = self.prefix_inputs

    def candidate_stream(self, n_candidates, sortmap=True):
        candidate_stream = DataStream(self.train_dataset,
                                      iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
        if not data.tvt:
            candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids)
        candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream)
        candidate_stream = transformers.taxi_add_datetime(candidate_stream)

        if not data.tvt:
            candidate_stream = transformers.add_destination(candidate_stream)

        if sortmap:
            candidate_stream = transformers.balanced_batch(candidate_stream,
                                                           key='latitude',
                                                           batch_size=n_candidates,
                                                           batch_sort_size=self.config.batch_sort_size)
        else:
            candidate_stream = Batch(candidate_stream,
                                     iteration_scheme=ConstantScheme(n_candidates))

        candidate_stream = Padding(candidate_stream,
                                   mask_sources=['latitude', 'longitude'])

        return candidate_stream

    def train(self, req_vars):
        prefix_stream = DataStream(self.train_dataset,
                                   iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))

        if not data.tvt:
            prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids)
        prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
        prefix_stream = transformers.TaxiGenerateSplits(prefix_stream,
                                                        max_splits=self.config.max_splits)

        prefix_stream = transformers.taxi_add_datetime(prefix_stream)

        prefix_stream = transformers.balanced_batch(prefix_stream,
                                                    key='latitude',
                                                    batch_size=self.config.batch_size,
                                                    batch_sort_size=self.config.batch_sort_size)

        prefix_stream = Padding(prefix_stream, mask_sources=['latitude', 'longitude'])

        candidate_stream = self.candidate_stream(self.config.train_candidate_size)

        sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)

        stream = transformers.Select(stream, tuple(req_vars))
        # stream = MultiProcessing(stream)
        return stream
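# The recurrent variant feeds whole GPS sequences, so batching differs from
# StreamSimple: transformers.balanced_batch (presumably sorting a window of
# batch_sort_size batches by sequence length, keyed on 'latitude') groups
# similarly long trips together, and Padding then pads each batch to its
# longest trip while emitting 'latitude_mask'/'longitude_mask' sources that
# flag real time steps -- matching the extra mask inputs declared in __init__.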
    def valid(self, req_vars):
        prefix_stream = DataStream(
            self.valid_dataset,
            iteration_scheme=SequentialExampleScheme(self.valid_dataset.num_examples))

        # prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)

        prefix_stream = transformers.taxi_add_datetime(prefix_stream)

        prefix_stream = transformers.balanced_batch(prefix_stream,
                                                    key='latitude',
                                                    batch_size=self.config.batch_size,
                                                    batch_sort_size=self.config.batch_sort_size)

        prefix_stream = Padding(prefix_stream, mask_sources=['latitude', 'longitude'])

        candidate_stream = self.candidate_stream(self.config.valid_candidate_size)

        sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)

        stream = transformers.Select(stream, tuple(req_vars))
        # stream = MultiProcessing(stream)

        return stream

    def test(self, req_vars):
        prefix_stream = DataStream(
            self.test_dataset,
            iteration_scheme=SequentialExampleScheme(self.test_dataset.num_examples))

        prefix_stream = transformers.taxi_add_datetime(prefix_stream)
        if not data.tvt:
            prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream)

        prefix_stream = Batch(prefix_stream,
                              iteration_scheme=ConstantScheme(self.config.batch_size))
        prefix_stream = Padding(prefix_stream, mask_sources=['latitude', 'longitude'])

        candidate_stream = self.candidate_stream(self.config.test_candidate_size, False)

        sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)

        stream = transformers.Select(stream, tuple(req_vars))
        # stream = MultiProcessing(stream)

        return stream
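# Config attributes referenced in this module (all expected on the `config`
# object passed in): normalize_representation, n_begin_end_pts, max_splits,
# batch_size, batch_sort_size, and {train,valid,test}_candidate_size.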