rnn_lag_tgtcls.py (1253B)
"""Recurrent model with lagged GPS inputs and a target-class softmax output.

Each recurrent step consumes four streams: latitude, longitude and their
``_lag`` counterparts.  The output layer is a softmax over the rows of
``config.tgtcls``, and the prediction is the probability-weighted combination
of those rows.  (Presumably the rows are class-centre coordinates; confirm this
against the config that builds ``tgtcls``.)
"""
import numpy
import theano
from theano import tensor
from blocks.bricks.base import lazy
from blocks.bricks import Softmax

# NOTE(review): ``Stream`` is not used in this module.  It is presumably
# imported so callers can fetch it from here alongside ``Model``, so keep it.
from model.rnn import RNN, Stream


class Model(RNN):
    """RNN variant fed with lagged coordinates that predicts via a class softmax."""

    @lazy()
    def __init__(self, config, **kwargs):
        """Build the model from ``config``.

        :param config: configuration object.  ``config.tgtcls`` must be
            array-like with a ``shape`` attribute; its first dimension is
            used as the number of output classes.
        :param kwargs: forwarded unchanged to :class:`RNN`.
        """
        num_classes = config.tgtcls.shape[0]
        # Four recurrent inputs per step, matching the tuple built in rec_input.
        super(Model, self).__init__(config, rec_input_len=4,
                                    output_dim=num_classes, **kwargs)

        class_table = numpy.array(config.tgtcls, dtype=theano.config.floatX)
        self.classes = theano.shared(class_table, name='classes')

        self.softmax = Softmax()
        # Extra per-step sequences that rec_input expects to receive.
        self.sequences.extend(['latitude_lag', 'longitude_lag'])
        # Register the softmax as a child brick of this model.
        self.children.append(self.softmax)

    def before_predict_all(self, kwargs):
        """Add ``latitude_lag`` / ``longitude_lag`` to ``kwargs`` in place.

        The parent hook runs first, so the lag tensors are derived from
        whatever it leaves in ``kwargs['latitude']`` and
        ``kwargs['longitude']``.
        """
        super(Model, self).before_predict_all(kwargs)
        # NOTE(review): repeat(x, 2, axis=0) duplicates every entry along
        # axis 0 (x0, x0, x1, x1, ...) and doubles that axis' length.  It is
        # not a one-step shift, so confirm this is the intended "lag".
        for source, lagged in (('latitude', 'latitude_lag'),
                               ('longitude', 'longitude_lag')):
            kwargs[lagged] = tensor.extra_ops.repeat(kwargs[source], 2, axis=0)

    def process_rto(self, rto):
        """Turn ``rto`` activations into a prediction.

        ``rto`` is presumably the output of the parent's rec-to-output layer.
        Applies the softmax, then returns the dot product of the class
        probabilities with the ``classes`` table.
        """
        class_probs = self.softmax.apply(rto)
        return tensor.dot(class_probs, self.classes)

    def rec_input(self, latitude, longitude, latitude_lag, longitude_lag, **kwargs):
        """Return the four recurrent inputs as a tuple.

        Each input gets a trailing broadcastable axis appended via
        ``tensor.shape_padright``.  The order is latitude, longitude,
        latitude_lag, longitude_lag.  Extra keyword arguments are accepted
        and ignored.
        """
        streams = (latitude, longitude, latitude_lag, longitude_lag)
        return tuple(tensor.shape_padright(s) for s in streams)