commit 6881a2a302c5abc3e2ef4b710fa2033ce83615ea
parent c61b71b63396648f490d9cb10e31de2bcdba601f
Author: Étienne Simon <esimon@esimon.eu>
Date:   Wed, 30 Apr 2014 15:52:35 +0200
Save epoch number
Diffstat:
1 file changed, 9 insertions(+), 6 deletions(-)
diff --git a/model.py b/model.py
@@ -43,11 +43,13 @@ class Model(object):
 
         if filepath is None:
             Relations = config['relations']
+            self.epoch = 0
             self.embeddings = Embeddings(config['rng'], dataset.number_embeddings, config['dimension'], self.tag+'.embeddings')
             self.relations = Relations(config['rng'], dataset.number_relations, config['dimension'], self.tag+'.relations')
         else:
             log('## Loading model from "{0}"\n'.format(filepath))
             with open(filepath, 'rb') as file:
+                self.epoch = cPickle.load(file)
                 self.embeddings = cPickle.load(file)
                 self.relations = cPickle.load(file)
 
@@ -122,13 +124,14 @@ class Model(object):
         validation_frequency = self.config['validation frequency']
         number_epoch = self.config['number of epoch']
 
-        for epoch in xrange(number_epoch):
+        while self.epoch < number_epoch:
             for (relation, left_positive, right_positive, left_negative, right_negative) in self.dataset.training_minibatch(batch_size):
                 self.normalise_function()
                 self.train_function(relation, left_positive, right_positive, left_negative, right_negative)
 
-            if (epoch+1) % validation_frequency == 0:
-                self.validate(epoch+1)
+            self.epoch += 1
+            if self.epoch % validation_frequency == 0:
+                self.validate()
 
     def error(self, name, transform_scores=(lambda x: x)):
         """ Compute the mean rank, standard deviation and top 10 on a given data. """
@@ -147,12 +150,12 @@ class Model(object):
         top10 = numpy.mean(map(lambda x: x<=10, result))
         return (mean, std, top10)
 
-    def validate(self, epoch):
+    def validate(self):
         """ Validate the model. """
-        log('Validation epoch {:<5}'.format(epoch))
+        log('Validation epoch {:<5}'.format(self.epoch))
         (valid_mean, valid_std, valid_top10) = self.error('valid')
         log(' valid mean: {0:<15} valid std: {1:<15} valid top10: {2:<15}'.format(valid_mean, valid_std, valid_top10))
-        datalog(self.config['datalog path']+'/'+self.config['model name'], epoch, valid_mean, valid_std, valid_top10)
+        datalog(self.config['datalog path']+'/'+self.config['model name'], self.epoch, valid_mean, valid_std, valid_top10)
         if not hasattr(self, 'best_mean') or valid_mean < self.best_mean:
             self.best_mean = valid_mean
             log('(best so far')