commit 051a8e243443b93520a4c6bd506abe1220550802
parent 4be068298edfd7777e6de5572419f9c2bdcbd985
Author: Étienne Simon <esimon@esimon.eu>
Date:   Wed, 16 Apr 2014 18:26:15 +0200
Add and change hyperparameters
Diffstat:
1 file changed, 11 insertions(+), 8 deletions(-)
diff --git a/model.py b/model.py
@@ -93,17 +93,17 @@ class Model(object):
         Keyword arguments:
         cost -- The cost to optimise.
         """
-        lr_relations = self.hyperparameters['relation_learning_rate']
-        lr_embeddings = self.hyperparameters['embeddings_learning_rate']
+        lr_relations = self.hyperparameters['relation learning rate']
+        lr_embeddings = self.hyperparameters['embeddings learning rate']
         return self.relations.updates(cost, lr_relations) + self.embeddings.updates(cost, lr_embeddings)
 
     def train(self):
         """ Train the model. """
         print >>sys.stderr, '# Training the model "{0}"'.format(self.tag)
 
-        batch_size = self.hyperparameters['train_batch_size']
-        validation_frequency = self.hyperparameters['validation_frequency']
-        number_epoch = self.hyperparameters['number_epoch']
+        batch_size = self.hyperparameters['train batch size']
+        validation_frequency = self.hyperparameters['validation frequency']
+        number_epoch = self.hyperparameters['number of epoch']
 
         for epoch in xrange(number_epoch):
             if epoch % validation_frequency == 0:
@@ -115,7 +115,7 @@ class Model(object):
 
     def error(self, name):
         """ Compute the mean rank and top 10 on a given data. """
-        batch_size = self.hyperparameters['test_batch_size']
+        batch_size = self.hyperparameters['test batch size']
         count, mean, top10 = 0, 0, 0
         for (relation, left, right) in self.dataset.iterate(name, batch_size): # TODO Test symmetric
             scores = None
@@ -138,8 +138,11 @@ class Model(object):
         print >>sys.stderr, 'Validation epoch {:<5}'.format(epoch),
         (valid_mean, valid_top10) = self.error('valid')
         print >>sys.stderr, 'valid mean: {0:<15} valid top10: {1:<15}'.format(valid_mean, valid_top10),
-        (train_mean, train_top10) = self.error('train')
-        print >>sys.stderr, 'train mean: {0:<15} train top10: {1:<15}'.format(train_mean, train_top10)
+        if self.hyperparameters['validate on training data']:
+            (train_mean, train_top10) = self.error('train')
+            print >>sys.stderr, 'train mean: {0:<15} train top10: {1:<15}'.format(train_mean, train_top10)
+        else
+            print >>sys.stderr, ''
 
     def test(self):
         """ Test the model. """