rnn_tgtcls_1.py (954B)
"""Configuration for the ``model.rnn_tgtcls`` model.

Exposes as module-level names:
  - the ``Model`` and ``Stream`` classes from ``model.rnn_tgtcls``;
  - two embedding specifications (``pre_embedder``, ``post_embedder``);
  - ``tgtcls``, loaded from ``arrival-clusters.pkl`` under ``data.path``;
  - initialisation schemes and batch-size hyper-parameters.
"""
import os
import cPickle

from blocks.initialization import IsotropicGaussian, Constant

import data
from model.rnn_tgtcls import Model, Stream


class EmbedderConfig(object):
    """Attribute container describing a group of embedding tables."""
    # __slots__ limits instances to exactly these attributes, so a misspelled
    # attribute assignment raises AttributeError instead of passing silently.
    __slots__ = ('dim_embeddings', 'embed_weights_init')


# Each dim_embeddings entry appears to be
# (feature name, number of distinct values, embedding size); the counts
# line up with 52 weeks, 7 weekdays and 24 * 4 quarter-hours.
# NOTE(review): confirm the tuple layout against model.rnn_tgtcls.
pre_embedder = EmbedderConfig()
pre_embedder.embed_weights_init = IsotropicGaussian(0.001)
pre_embedder.dim_embeddings = [
    ('week_of_year', 52, 10),
    ('day_of_week', 7, 10),
    ('qhour_of_day', 24 * 4, 10),
    ('day_type', 3, 10),
    ('taxi_id', 448, 10),
]

# Vocabulary sizes for these features come from the `data` module rather
# than being hard-coded.
post_embedder = EmbedderConfig()
post_embedder.embed_weights_init = IsotropicGaussian(0.001)
post_embedder.dim_embeddings = [
    ('origin_call', data.origin_call_train_size, 10),
    ('origin_stand', data.stands_size, 10),
]

# Target classes, presumably precomputed arrival-point clusters
# (TODO confirm against the script that writes arrival-clusters.pkl).
# Open in binary mode: pickle data is a byte stream, and text mode can
# corrupt it on platforms that translate line endings.
with open(os.path.join(data.path, 'arrival-clusters.pkl'), 'rb') as f:
    tgtcls = cPickle.load(f)

# Recurrent-layer size and parameter initialisation.
hidden_state_dim = 100
weights_init = IsotropicGaussian(0.01)
biases_init = Constant(0.001)

# NOTE(review): batch_sort_size is presumably the number of batches read
# ahead and sorted together by the stream; confirm in model.rnn_tgtcls.
batch_size = 10
batch_sort_size = 10