memory_network_bidir.py (4499B)
from theano import tensor

from toolz import merge

from blocks.bricks import application, MLP, Rectifier, Initializable, Softmax, Linear
from blocks.bricks.parallel import Fork
from blocks.bricks.recurrent import Bidirectional, LSTM

import data
from data import transformers
from data.cut import TaxiTimeCutScheme
from data.hdf5 import TaxiDataset, TaxiStream
import error
from model import ContextEmbedder

from memory_network import StreamRecurrent as Stream
from memory_network import MemoryNetworkBase
from bidirectional import SegregatedBidirectional


class RecurrentEncoder(Initializable):
    """Encode a GPS trajectory plus its context features into one vector.

    The normalized (latitude, longitude) sequence is fed to a
    ``SegregatedBidirectional`` LSTM whose two directions receive separate
    input projections (``fwd_fork`` / ``bkwd_fork``). One state from each
    direction is concatenated with the context embeddings and passed through
    an MLP (``rec_to_output``) to produce the output.

    Parameters
    ----------
    config : object
        Encoder configuration. Attributes read in this class:
        ``rec_state_dim``, ``dim_embeddings`` (the third item of each entry
        is summed into the MLP input width), ``dim_hidden`` (hidden layer
        sizes, each followed by a ``Rectifier``), ``weights_init`` and
        ``biases_init``. It is also handed to ``ContextEmbedder``.
    output_dim : int
        Width of the final MLP layer.
    activation : brick
        Activation brick applied after the final MLP layer.
    """

    def __init__(self, config, output_dim, activation, **kwargs):
        super(RecurrentEncoder, self).__init__(**kwargs)

        self.config = config
        self.context_embedder = ContextEmbedder(config)

        self.rec = SegregatedBidirectional(
            LSTM(dim=config.rec_state_dim, name='encoder_recurrent'))

        # One Linear projection per non-mask sequence input of the LSTM, for
        # each direction separately.
        self.fwd_fork = Fork(
            [name for name in self.rec.prototype.apply.sequences if name != 'mask'],
            prototype=Linear(), name='fwd_fork')
        self.bkwd_fork = Fork(
            [name for name in self.rec.prototype.apply.sequences if name != 'mask'],
            prototype=Linear(), name='bkwd_fork')

        # MLP input = one rec_state_dim-wide state per direction + embeddings.
        rto_in = config.rec_state_dim * 2 + sum(x[2] for x in config.dim_embeddings)
        self.rec_to_output = MLP(
            activations=[Rectifier() for _ in config.dim_hidden] + [activation],
            dims=[rto_in] + config.dim_hidden + [output_dim],
            name='encoder_rto')

        self.children = [self.context_embedder, self.rec, self.fwd_fork,
                         self.bkwd_fork, self.rec_to_output]

        self.rec_inputs = ['latitude', 'longitude', 'latitude_mask']
        self.inputs = self.context_embedder.inputs + self.rec_inputs

    def _push_allocation_config(self):
        """Size each fork: 2 inputs (lat, lon) -> the matching network's input dims."""
        # Fork i feeds self.rec.children[i]; presumably children[0] is the
        # forward network and children[1] the backward one -- TODO confirm in
        # SegregatedBidirectional.
        for i, fork in enumerate([self.fwd_fork, self.bkwd_fork]):
            fork.input_dim = 2
            fork.output_dims = [self.rec.children[i].get_dim(name)
                                for name in fork.output_names]

    def _push_initialization_config(self):
        """Give every child brick the initializers from the config."""
        for brick in self.children:
            brick.weights_init = self.config.weights_init
            brick.biases_init = self.config.biases_init

    @application
    def apply(self, latitude, longitude, latitude_mask, **kwargs):
        """Encode a batch of trajectories.

        ``latitude``, ``longitude`` and ``latitude_mask`` are transposed
        before use, so they are presumably (batch, time) matrices -- TODO
        confirm against the data stream. ``kwargs`` must contain every name
        listed in ``self.context_embedder.inputs``.

        Returns the output of ``rec_to_output`` for the batch.
        """
        # Normalize with the training-set GPS statistics and switch to the
        # time-major layout used by the recurrent bricks.
        latitude = (latitude.T - data.train_gps_mean[0]) / data.train_gps_std[0]
        longitude = (longitude.T - data.train_gps_mean[1]) / data.train_gps_std[1]
        latitude_mask = latitude_mask.T

        # Stack both coordinates on a new trailing axis: (time, batch, 2).
        rec_in = tensor.concatenate(
            (latitude[:, :, None], longitude[:, :, None]), axis=2)

        fwd_inputs = merge(self.fwd_fork.apply(rec_in, as_dict=True),
                           {'mask': latitude_mask})
        bkwd_inputs = merge(self.bkwd_fork.apply(rec_in, as_dict=True),
                            {'mask': latitude_mask})
        path = self.rec.apply(fwd_inputs, bkwd_inputs)[0]

        dim = self.config.rec_state_dim
        # Index of the last unmasked step per sequence (length - 1); assumes
        # the mask is a contiguous prefix of ones.
        last_id = tensor.cast(latitude_mask.sum(axis=0) - 1, dtype='int64')

        path_representation = (
            # Last `dim` features at t=0 (presumably the backward direction,
            # which has consumed the whole sequence by the time it reaches t=0).
            path[0][:, -dim:],
            # First `dim` features at the last valid step. Index with last_id
            # itself: last_id - 1 would read the second-to-last step and drop
            # the most recent GPS point.
            path[last_id, tensor.arange(last_id.shape[0])][:, :dim],
        )

        embeddings = tuple(self.context_embedder.apply(
            **{k: kwargs[k] for k in self.context_embedder.inputs}))

        inputs = tensor.concatenate(path_representation + embeddings, axis=1)
        return self.rec_to_output.apply(inputs)

    @apply.property('inputs')
    def apply_inputs(self):
        """Input names of ``apply``: context-embedder inputs + GPS sequences."""
        return self.inputs


class Model(MemoryNetworkBase):
    """Memory network whose prefix and candidate encoders are ``RecurrentEncoder``s.

    Both encoders emit ``config.representation_size``-wide vectors using a
    fresh ``config.representation_activation()`` brick. Everything else is
    handled by ``MemoryNetworkBase``.
    """

    def __init__(self, config, **kwargs):
        # Prefix encoder: bidirectional recurrent net, then MLP.
        prefix_encoder = RecurrentEncoder(config.prefix_encoder,
                                          config.representation_size,
                                          config.representation_activation(),
                                          name='prefix_encoder')

        # Candidate encoder: same architecture, separate parameters/config.
        candidate_encoder = RecurrentEncoder(config.candidate_encoder,
                                             config.representation_size,
                                             config.representation_activation(),
                                             name='candidate_encoder')

        super(Model, self).__init__(config, prefix_encoder, candidate_encoder, **kwargs)