rnn.py (5883B)
from theano import tensor

from blocks.bricks import application, MLP, Initializable, Tanh
from blocks.bricks.base import lazy
from blocks.bricks.recurrent import LSTM, recurrent
from blocks.utils import shared_floatx_zeros

from fuel.transformers import Batch, Padding
from fuel.streams import DataStream
from fuel.schemes import ConstantScheme, ShuffledExampleScheme

from model import ContextEmbedder
import data
from data import transformers
from data.hdf5 import TaxiDataset, TaxiStream
import error


from model.stream import StreamRec as Stream


class RNN(Initializable):
    """Recurrent destination predictor built from Blocks bricks.

    The brick chains five children: a "pre" context embedder, an MLP
    mapping (embeddings + per-step input) to the hidden size, an LSTM,
    a "post" context embedder and an MLP mapping (post embeddings +
    LSTM state) to the output.

    Parameters
    ----------
    config
        Object providing ``pre_embedder``, ``post_embedder``,
        ``hidden_state_dim``, ``weights_init`` and ``biases_init``.
    rec_input_len
        Number of per-step input features produced by ``rec_input``.
    output_dim
        Width of the output MLP.
    """

    @lazy()
    def __init__(self, config, rec_input_len=2, output_dim=2, **kwargs):
        super(RNN, self).__init__(**kwargs)
        self.config = config

        hidden_dim = config.hidden_state_dim

        self.pre_context_embedder = ContextEmbedder(
            config.pre_embedder, name='pre_context_embedder')
        self.post_context_embedder = ContextEmbedder(
            config.post_embedder, name='post_context_embedder')

        # The third field of every ``dim_embeddings`` entry is used as
        # the width contributed by that embedding.
        pre_width = sum(
            spec[2] for spec in config.pre_embedder.dim_embeddings)
        self.input_to_rec = MLP(
            activations=[Tanh()],
            dims=[rec_input_len + pre_width, hidden_dim],
            name='input_to_rec')

        self.rec = LSTM(dim=hidden_dim, name='recurrent')

        post_width = sum(
            spec[2] for spec in config.post_embedder.dim_embeddings)
        self.rec_to_output = MLP(
            activations=[Tanh()],
            dims=[hidden_dim + post_width, output_dim],
            name='rec_to_output')

        # Names of the variables consumed by the application methods.
        self.sequences = ['latitude', 'latitude_mask', 'longitude']
        self.context = (self.pre_context_embedder.inputs
                        + self.post_context_embedder.inputs)
        self.inputs = self.sequences + self.context
        self.children = [
            self.pre_context_embedder,
            self.post_context_embedder,
            self.input_to_rec,
            self.rec,
            self.rec_to_output,
        ]

        # NOTE(review): these two shared variables are not read anywhere
        # in this class (``initial_states`` delegates to the LSTM);
        # kept because external code may rely on them.
        self.initial_state_ = shared_floatx_zeros(
            (hidden_dim,), name="initial_state")
        self.initial_cells = shared_floatx_zeros(
            (hidden_dim,), name="initial_cells")

    def _push_initialization_config(self):
        """Copy the initialisation schemes from ``config`` to children.

        Both MLPs receive weight and bias schemes; the LSTM only
        receives the weight scheme.
        """
        weights_scheme = self.config.weights_init
        for child_mlp in (self.input_to_rec, self.rec_to_output):
            child_mlp.weights_init = weights_scheme
            child_mlp.biases_init = self.config.biases_init
        self.rec.weights_init = weights_scheme

    def get_dim(self, name):
        """Return the dimension of ``name`` as reported by the LSTM."""
        return self.rec.get_dim(name)

    def process_rto(self, rto):
        """Post-process the output MLP result; identity here.

        Presumably a hook for subclasses to override -- confirm.
        """
        return rto

    def rec_input(self, latitude, longitude, **kwargs):
        """Return the per-step inputs as a 2-tuple.

        Each of ``latitude`` and ``longitude`` gets one trailing axis
        appended so they can be concatenated along axis 1.
        """
        lat_column = tensor.shape_padright(latitude)
        lon_column = tensor.shape_padright(longitude)
        return (lat_column, lon_column)

    @recurrent(states=['states', 'cells'],
               outputs=['destination', 'states', 'cells'])
    def predict_all(self, **kwargs):
        """Run one recurrent step.

        Returns
        -------
        tuple
            ``(destination, next_states, next_cells)``.
        """
        pre_embeddings = tuple(self.pre_context_embedder.apply(**kwargs))
        step_features = tensor.concatenate(
            pre_embeddings + self.rec_input(**kwargs), axis=1)

        hidden_in = self.input_to_rec.apply(step_features)
        # Every column is repeated four times in place; presumably this
        # matches the ``4 * dim`` input width Blocks' LSTM expects --
        # confirm against the Blocks LSTM documentation.
        hidden_in = hidden_in.repeat(4, axis=1)

        next_states, next_cells = self.rec.apply(
            hidden_in,
            kwargs['states'],
            kwargs['cells'],
            mask=kwargs['latitude_mask'],
            iterate=False)

        post_embeddings = tuple(self.post_context_embedder.apply(**kwargs))
        output_features = tensor.concatenate(
            post_embeddings + (next_states,), axis=1)
        raw_output = self.rec_to_output.apply(output_features)

        return (self.process_rto(raw_output), next_states, next_cells)

    @predict_all.property('sequences')
    def predict_all_sequences(self):
        """Names of the sequence inputs of ``predict_all``."""
        return self.sequences

    @application(outputs=predict_all.states)
    def initial_states(self, *args, **kwargs):
        """Delegate the initial states to the LSTM child."""
        return self.rec.initial_states(*args, **kwargs)

    @predict_all.property('contexts')
    def predict_all_context(self):
        """Names of the context inputs of ``predict_all``."""
        return self.context

    def before_predict_all(self, kwargs):
        """Prepare ``kwargs`` in place for ``predict_all``.

        Latitude and longitude are transposed, then standardised with
        ``data.train_gps_mean`` / ``data.train_gps_std`` (index 0 for
        latitude, 1 for longitude). The mask is only transposed.
        Presumably inputs arrive as (batch, time) and leave as
        (time, batch) -- confirm against the data stream.
        """
        for gps_idx, key in enumerate(('latitude', 'longitude')):
            centred = kwargs[key].T - data.train_gps_mean[gps_idx]
            kwargs[key] = centred / data.train_gps_std[gps_idx]
        kwargs['latitude_mask'] = kwargs['latitude_mask'].T

    @application(outputs=['destination'])
    def predict(self, **kwargs):
        """Return, for each example, the output at its last valid step.

        The last step index is the column sum of the (transposed) mask
        minus one.
        """
        self.before_predict_all(kwargs)
        mask = kwargs['latitude_mask']

        every_step = self.predict_all(**kwargs)[0]

        last_step = tensor.cast(mask.sum(axis=0) - 1, dtype='int64')
        example_idx = tensor.arange(mask.shape[1])
        return every_step[last_step, example_idx]

    @predict.property('inputs')
    def predict_inputs(self):
        """Names of the inputs of ``predict``."""
        return self.inputs

    @application(outputs=['cost_matrix'])
    def cost_matrix(self, **kwargs):
        """Return the masked per-step error against the destination.

        The destination is tiled along the first axis so every step is
        compared with the same target. The distance itself comes from
        ``error.erdist`` (semantics defined in the ``error`` module).
        """
        self.before_predict_all(kwargs)
        steps = kwargs['latitude']

        every_step = self.predict_all(**kwargs)[0]

        # Each destination vector becomes (1, n, 1) before the two are
        # joined on the last axis.
        dest_lat = kwargs['destination_latitude'].dimshuffle('x', 0, 'x')
        dest_lon = kwargs['destination_longitude'].dimshuffle('x', 0, 'x')
        target = tensor.concatenate((dest_lat, dest_lon), axis=2)
        target = target.repeat(steps.shape[0], axis=0)

        flat_error = error.erdist(
            target.reshape((-1, 2)), every_step.reshape((-1, 2)))
        step_error = flat_error.reshape(steps.shape)
        return step_error * kwargs['latitude_mask']

    @cost_matrix.property('inputs')
    def cost_matrix_inputs(self):
        """Names of the inputs of ``cost_matrix``."""
        return self.inputs + ['destination_latitude', 'destination_longitude']

    @application(outputs=['cost'])
    def cost(self, latitude_mask, **kwargs):
        """Return the sum of masked errors divided by the mask sum."""
        masked_error = self.cost_matrix(latitude_mask=latitude_mask, **kwargs)
        return masked_error.sum() / latitude_mask.sum()

    @cost.property('inputs')
    def cost_inputs(self):
        """Names of the inputs of ``cost``."""
        return self.inputs + ['destination_latitude', 'destination_longitude']

    @application(outputs=['cost'])
    def valid_cost(self, **kwargs):
        """Return the mean error taken at each example's last valid step.

        The mask here is the untransposed one (``cost_matrix`` works on
        its own copy of the keyword arguments), hence the reduction
        over axis 1 and the use of ``shape[0]``.
        """
        mask = kwargs['latitude_mask']
        last_step = tensor.cast(mask.sum(axis=1) - 1, dtype='int64')
        example_idx = tensor.arange(mask.shape[0])
        return self.cost_matrix(**kwargs)[last_step, example_idx].mean()

    @valid_cost.property('inputs')
    def valid_cost_inputs(self):
        """Names of the inputs of ``valid_cost``."""
        return self.inputs + ['destination_latitude', 'destination_longitude']