commit 5096e0cdae167122d07b09cd207a04f28ea5c3f5
parent 98139f573eb179c8f5a06ba6c8d8883376814ccf
Author: Alex Auvolat <alex.auvolat@ens.fr>
Date: Thu, 2 Jul 2015 13:23:28 -0400
Add random seed for TaxiGenerateSplits and for fuel
Diffstat:
2 files changed, 17 insertions(+), 9 deletions(-)
diff --git a/data/transformers.py b/data/transformers.py
@@ -1,8 +1,10 @@
import datetime
-import random
import numpy
import theano
+
+import fuel
+
from fuel.schemes import ConstantScheme
from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack
@@ -66,13 +68,15 @@ class TaxiGenerateSplits(Transformer):
self.id_latitude = data_stream.sources.index('latitude')
self.id_longitude = data_stream.sources.index('longitude')
+ self.rng = numpy.random.RandomState(fuel.config.default_seed)
+
def get_data(self, request=None):
if request is not None:
raise ValueError
while self.isplit >= len(self.splits):
self.data = next(self.child_epoch_iterator)
self.splits = range(len(self.data[self.id_longitude]))
- random.shuffle(self.splits)
+ self.rng.shuffle(self.splits)
if self.max_splits != -1 and len(self.splits) > self.max_splits:
self.splits = self.splits[:self.max_splits]
self.isplit = 0
diff --git a/train.py b/train.py
@@ -9,12 +9,16 @@ from functools import reduce
from theano import tensor
+import blocks
+import fuel
+
from blocks import roles
from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule, Momentum
from blocks.extensions import Printing, FinishAfter
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
-import blocks
+
blocks.config.default_seed = 123
+fuel.config.default_seed = 123
try:
from blocks.extras.extensions.plotting import Plot
@@ -104,12 +108,12 @@ if __name__ == "__main__":
every_n_batches=1000),
Printing(every_n_batches=1000),
- SaveLoadParams(dump_path, cg,
- before_training=True, # before training -> load params
- every_n_batches=1000, # every N batches -> save params
- after_epoch=True, # after epoch -> save params
- after_training=True, # after training -> save params
- ),
+ # SaveLoadParams(dump_path, cg,
+ # before_training=True, # before training -> load params
+ # every_n_batches=1000, # every N batches -> save params
+ # after_epoch=True, # after epoch -> save params
+ # after_training=True, # after training -> save params
+ # ),
RunOnTest(model_name,
model,