commit 57fe795d14e70c06c9bdbe6fe903588b5f75474e parent 448e848796757ad9f0a2f681886f868b8f22e81f Author: Alex Auvolat <alex.auvolat@ens.fr> Date: Fri, 22 May 2015 15:51:26 -0400 Add parametrizability for how the training data is presented Diffstat:
18 files changed, 29 insertions(+), 5 deletions(-)
diff --git a/config/dest_simple_mlp_2_cs.py b/config/dest_simple_mlp_2_cs.py @@ -26,3 +26,4 @@ momentum = 0.99 batch_size = 32 valid_set = 'cuts/test_times_0' +max_splits = 100 diff --git a/config/dest_simple_mlp_2_cswdt.py b/config/dest_simple_mlp_2_cswdt.py @@ -30,3 +30,4 @@ momentum = 0.99 batch_size = 32 valid_set = 'cuts/test_times_0' +max_splits = 100 diff --git a/config/dest_simple_mlp_2_noembed.py b/config/dest_simple_mlp_2_noembed.py @@ -23,3 +23,4 @@ momentum = 0.99 batch_size = 32 valid_set = 'cuts/test_times_0' +max_splits = 100 diff --git a/config/dest_simple_mlp_tgtcls_0_cs.py b/config/dest_simple_mlp_tgtcls_0_cs.py @@ -31,3 +31,4 @@ momentum = 0.99 batch_size = 32 valid_set = 'cuts/test_times_0' +max_splits = 100 diff --git a/config/dest_simple_mlp_tgtcls_1_cs.py b/config/dest_simple_mlp_tgtcls_1_cs.py @@ -31,3 +31,4 @@ momentum = 0.99 batch_size = 32 valid_set = 'cuts/test_times_0' +max_splits = 100 diff --git a/config/dest_simple_mlp_tgtcls_1_cswdt.py b/config/dest_simple_mlp_tgtcls_1_cswdt.py @@ -35,3 +35,4 @@ momentum = 0.99 batch_size = 32 valid_set = 'cuts/test_times_0' +max_splits = 100 diff --git a/config/dest_simple_mlp_tgtcls_1_cswdtx.py b/config/dest_simple_mlp_tgtcls_1_cswdtx.py @@ -33,6 +33,10 @@ mlp_biases_init = Constant(0.001) learning_rate = 0.0001 momentum = 0.99 -batch_size = 32 +batch_size = 100 + +use_cuts_for_training = True +max_splits = 1 valid_set = 'cuts/test_times_0' + diff --git a/config/dest_simple_mlp_tgtcls_1_cswdtx_alexandre.py b/config/dest_simple_mlp_tgtcls_1_cswdtx_alexandre.py @@ -36,3 +36,4 @@ momentum = 0.9 batch_size = 200 valid_set = 'cuts/test_times_0' +max_splits = 100 diff --git a/config/joint_simple_mlp_tgtcls_111_cswdtx.py b/config/joint_simple_mlp_tgtcls_111_cswdtx.py @@ -53,3 +53,4 @@ momentum = 0.99 batch_size = 200 valid_set = 'cuts/test_times_0' +max_splits = 100 diff --git a/config/joint_simple_mlp_tgtcls_111_cswdtx_bigger.py b/config/joint_simple_mlp_tgtcls_111_cswdtx_bigger.py @@ -54,3 +54,4 @@ batch_size = 200 valid_set = 'cuts/test_times_0' +max_splits = 100 diff --git a/config/joint_simple_mlp_tgtcls_111_cswdtx_bigger_dropout.py b/config/joint_simple_mlp_tgtcls_111_cswdtx_bigger_dropout.py @@ -57,3 +57,4 @@ batch_size = 200 valid_set = 'cuts/test_times_0' +max_splits = 100 diff --git a/config/joint_simple_mlp_tgtcls_111_cswdtx_noise_dout.py b/config/joint_simple_mlp_tgtcls_111_cswdtx_noise_dout.py @@ -60,3 +60,4 @@ noise = 0.01 noise_inputs = VariableFilter(roles=[roles.PARAMETER]) valid_set = 'cuts/test_times_0' +max_splits = 100 diff --git a/config/joint_simple_mlp_tgtcls_1_cswdtx.py b/config/joint_simple_mlp_tgtcls_1_cswdtx.py @@ -53,3 +53,4 @@ momentum = 0.99 batch_size = 200 valid_set = 'cuts/test_times_0' +max_splits = 100 diff --git a/config/joint_simple_mlp_tgtcls_1_cswdtx_bigger.py b/config/joint_simple_mlp_tgtcls_1_cswdtx_bigger.py @@ -53,3 +53,4 @@ momentum = 0.99 batch_size = 200 valid_set = 'cuts/test_times_0' +max_splits = 100 diff --git a/config/time_simple_mlp_1.py b/config/time_simple_mlp_1.py @@ -26,3 +26,4 @@ momentum = 0.99 batch_size = 32 valid_set = 'cuts/test_times_0' +max_splits = 100 diff --git a/config/time_simple_mlp_2_cswdtx.py b/config/time_simple_mlp_2_cswdtx.py @@ -33,3 +33,4 @@ momentum = 0.99 batch_size = 32 valid_set = 'cuts/test_times_0' +max_splits = 100 diff --git a/config/time_simple_mlp_tgtcls_2_cswdtx.py b/config/time_simple_mlp_tgtcls_2_cswdtx.py @@ -36,3 +36,4 @@ momentum = 0.99 batch_size = 32 valid_set = 'cuts/test_times_0' +max_splits = 100 diff --git a/model/mlp.py b/model/mlp.py @@ -51,14 +51,18 @@ class Stream(object): self.config = config def train(self, req_vars): - stream = TaxiDataset('train') - stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme()) - valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',)) valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] + stream = TaxiDataset('train') + + if hasattr(self.config, 'use_cuts_for_trainig') and self.config.use_cuts_for_training: + stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme()) + else: + stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) + stream = transformers.TaxiExcludeTrips(valid_trips_ids, stream) - stream = transformers.TaxiGenerateSplits(stream, max_splits=1) + stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits) stream = transformers.TaxiAddDateTime(stream) stream = transformers.TaxiAddFirstLastLen(self.config.n_begin_end_pts, stream)