commit f258bea01d6ea61fa97a4e0d9d81009bbbd354f8
parent 92227335bd42489fe0a25e7670b36d05fb4a519f
Author: Étienne Simon <esimon@esimon.eu>
Date:   Thu, 17 Apr 2014 12:18:07 +0200
Add auto-saving & Fix loading/saving
Diffstat:
2 files changed, 19 insertions(+), 5 deletions(-)
diff --git a/main.py b/main.py
@@ -9,18 +9,22 @@ from relations.translations import *
 
 if __name__ == '__main__':
     if len(sys.argv)<3:
-        print >>sys.stderr, 'Usage: {0} data parameters'.format(sys.argv[0])
+        print >>sys.stderr, 'Usage: {0} data parameters [model]'.format(sys.argv[0])
         sys.exit(1)
     data = sys.argv[1]
     config = sys.argv[2]
+    model_path = None if len(sys.argv)<4 else sys.argv[3]
 
     with open(config, 'r') as config_file:
         hyperparameters = json.load(config_file)
         for k, v in hyperparameters.iteritems():
-            if isinstance(v, basestring):
-                hyperparameters[k] = eval(v)
+            if isinstance(v, basestring) and v.startswith('python:'):
+                hyperparameters[k] = eval(v[7:])
 
     data = Dataset(data)
-    model = Model.initialise(Translations, data, hyperparameters, 'TransE')
+    if model_path is None:
+        model = Model.initialise(Translations, data, hyperparameters, 'TransE')
+    else:
+        model = Model.load(model_path, data, hyperparameters, 'TransE')
     model.train()
     model.test()
diff --git a/model.py b/model.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python2
 
 import sys
+import cPickle
 import numpy
 import scipy
 import theano
@@ -44,13 +45,14 @@ class Model(object):
         return self
 
     @classmethod
-    def load(cls, filepath, dataset, hyperparameters):
+    def load(cls, filepath, dataset, hyperparameters, tag):
         """ Load a model from a file.
 
         Keyword arguments:
         filepath -- path to the Model file
         dataset -- dataset on which the model will be trained and tested
         hyperparameters -- hyperparameters dictionary
+        tag -- name of the embeddings for parameter declaration
         """
         print >>sys.stderr, '# Loading model from "{0}"'.format(filepath)
 
@@ -60,6 +62,7 @@ class Model(object):
             self.relations = cPickle.load(file)
         self.dataset = dataset;
         self.hyperparameters = hyperparameters;
+        self.tag = tag
 
         self.build()
         return self
@@ -153,6 +156,13 @@ class Model(object):
         print >>sys.stderr, 'Validation epoch {:<5}'.format(epoch),
         (valid_mean, valid_top10) = self.error('valid')
         print >>sys.stderr, 'valid mean: {0:<15} valid top10: {1:<15}'.format(valid_mean, valid_top10),
+        if not hasattr(self, 'best_mean') or valid_mean > self.best_mean:
+            print >>sys.stderr, ' (best so far',
+            if self.hyperparameters['save best model']:
+                print >>sys.stderr, ' saving',
+                self.save(self.hyperparameters['best model save location'])
+            print >>sys.stderr, ')',
+
         if self.hyperparameters['validate on training data']:
             (train_mean, train_top10) = self.error('train')
             print >>sys.stderr, 'train mean: {0:<15} train top10: {1:<15}'.format(train_mean, train_top10)