bidirectional.py (4708B)
from theano import tensor

from toolz import merge

from blocks.bricks import application, MLP, Initializable, Linear, Rectifier, Identity
from blocks.bricks.base import lazy
from blocks.bricks.recurrent import Bidirectional, LSTM
from blocks.utils import shared_floatx_zeros
from blocks.bricks.parallel import Fork

from model import ContextEmbedder
import error

import data

# NOTE(review): `Stream` is not used in this file; presumably it is imported so
# that configs can reach it through this module — verify before removing.
from model.stream import StreamRec as Stream


class SegregatedBidirectional(Bidirectional):
    """Bidirectional recurrent brick whose two directions take separate inputs.

    The forward network (``children[0]``) is applied to ``forward_dict`` and
    the backward network (``children[1]``) to ``backward_dict``.
    """

    @application
    def apply(self, forward_dict, backward_dict):
        """Applies forward and backward networks and concatenates outputs.

        The backward network is run with ``reverse=True``; each of its outputs
        is reversed along axis 0 (``x[::-1]``) before being concatenated with
        the matching forward output along axis 2.

        Returns:
            A list with one concatenated tensor per recurrent output.
        """
        forward = self.children[0].apply(as_list=True, **forward_dict)
        backward = [x[::-1] for x in
                    self.children[1].apply(reverse=True, as_list=True,
                                           **backward_dict)]

        return [tensor.concatenate([f, b], axis=2)
                for f, b in zip(forward, backward)]


class BidiRNN(Initializable):
    """Bidirectional LSTM over a (latitude, longitude) sequence plus context.

    Each direction has its own linear projection (``fwd_fork`` and
    ``bkwd_fork``) of the 2-d GPS input. The two directional summaries are
    concatenated with the context embeddings and fed through an MLP.
    Subclasses must implement ``process_outputs``, which turns the MLP output
    into the ``destination`` returned by ``predict``.
    """

    @lazy()
    def __init__(self, config, output_dim=2, **kwargs):
        super(BidiRNN, self).__init__(**kwargs)
        self.config = config

        self.context_embedder = ContextEmbedder(config)

        act = config.rec_activation() if hasattr(config, 'rec_activation') else None
        self.rec = SegregatedBidirectional(LSTM(dim=config.hidden_state_dim,
                                                activation=act,
                                                name='recurrent'))

        # The forks project the GPS input into every recurrent sequence input
        # except the mask, which is supplied separately in `predict`.
        rec_input_names = [name for name in self.rec.prototype.apply.sequences
                           if name != 'mask']
        self.fwd_fork = Fork(list(rec_input_names),
                             prototype=Linear(), name='fwd_fork')
        self.bkwd_fork = Fork(list(rec_input_names),
                              prototype=Linear(), name='bkwd_fork')

        # The MLP input is the forward summary plus the backward summary
        # (hidden_state_dim each) plus the context embeddings.
        # Presumably x[2] is each embedding's width — verify against config.
        rto_in = config.hidden_state_dim * 2 + sum(x[2] for x in config.dim_embeddings)
        self.rec_to_output = MLP(
            activations=[Rectifier() for _ in config.dim_hidden] + [Identity()],
            dims=[rto_in] + config.dim_hidden + [output_dim])

        self.sequences = ['latitude', 'latitude_mask', 'longitude']
        self.inputs = self.sequences + self.context_embedder.inputs

        self.children = [self.context_embedder, self.fwd_fork, self.bkwd_fork,
                         self.rec, self.rec_to_output]

    def _push_allocation_config(self):
        """Size each fork: a 2-d (lat, lon) input projected to its LSTM's inputs."""
        for i, fork in enumerate([self.fwd_fork, self.bkwd_fork]):
            fork.input_dim = 2
            # rec.children[0] is the forward LSTM, rec.children[1] the backward one.
            fork.output_dims = [self.rec.children[i].get_dim(name)
                                for name in fork.output_names]

    def _push_initialization_config(self):
        """Copy weight/bias initialisation schemes from the config to the children.

        The context embedder is not included here.
        """
        for brick in [self.fwd_fork, self.bkwd_fork, self.rec, self.rec_to_output]:
            brick.weights_init = self.config.weights_init
            brick.biases_init = self.config.biases_init

    def process_outputs(self, outputs):
        """Map the MLP output to the predicted destination (subclass hook).

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError(
            'process_outputs must be implemented in a subclass of BidiRNN')

    @application(outputs=['destination'])
    def predict(self, latitude, longitude, latitude_mask, **kwargs):
        """Encode the GPS sequence and context, then return ``process_outputs``.

        Extra keyword arguments must include every name listed in
        ``self.context_embedder.inputs``.
        """
        # Transpose so that axis 0 is time and axis 1 the batch, and
        # standardise with the training-set GPS statistics.
        latitude = (latitude.T - data.train_gps_mean[0]) / data.train_gps_std[0]
        longitude = (longitude.T - data.train_gps_mean[1]) / data.train_gps_std[1]
        latitude_mask = latitude_mask.T

        rec_in = tensor.concatenate((latitude[:, :, None], longitude[:, :, None]), axis=2)

        # Index of the last unmasked time step of every sequence in the batch.
        last_id = tensor.cast(latitude_mask.sum(axis=0) - 1, dtype='int64')
        batch_index = tensor.arange(latitude_mask.shape[1])

        path = self.rec.apply(merge(self.fwd_fork.apply(rec_in, as_dict=True),
                                    {'mask': latitude_mask}),
                              merge(self.bkwd_fork.apply(rec_in, as_dict=True),
                                    {'mask': latitude_mask}))[0]

        hidden_dim = self.config.hidden_state_dim
        # Backward summary: the backward half of the state at t = 0.
        backward_summary = path[0][:, -hidden_dim:]
        # Forward summary: the forward half of the state at each sequence's last
        # unmasked step. This used to index `last_id - 1`, which read the state
        # one step early and ignored the final GPS point.
        forward_summary = path[last_id, batch_index][:, :hidden_dim]
        # Order (backward, forward) is kept so the MLP input layout is unchanged.
        path_representation = (backward_summary, forward_summary)

        embeddings = tuple(self.context_embedder.apply(
            **{k: kwargs[k] for k in self.context_embedder.inputs}))

        inputs = tensor.concatenate(path_representation + embeddings, axis=1)
        outputs = self.rec_to_output.apply(inputs)

        return self.process_outputs(outputs)

    @predict.property('inputs')
    def predict_inputs(self):
        return self.inputs

    @application(outputs=['cost'])
    def cost(self, **kwargs):
        """Mean ``error.erdist`` between predicted and true destinations."""
        y_hat = self.predict(**kwargs)
        y = tensor.concatenate((kwargs['destination_latitude'][:, None],
                                kwargs['destination_longitude'][:, None]), axis=1)

        return error.erdist(y_hat, y).mean()

    @cost.property('inputs')
    def cost_inputs(self):
        return self.inputs + ['destination_latitude', 'destination_longitude']