commit 29a37fe435f216b21cb684071597a32779abe78a
parent 95ff4ae00954ae9a0ffb41297b2f65da251f183b
Author: Étienne Simon <esimon@esimon.eu>
Date:   Fri, 13 Jun 2014 13:59:13 +0200
Datalog rank distribution
Diffstat:
1 file changed, 8 insertions(+), 6 deletions(-)
diff --git a/model.py b/model.py
@@ -149,14 +149,15 @@ class Model(object):
         mean = numpy.mean(result)
         std = numpy.std(result)
         top10 = numpy.mean(map(lambda x: x<=10, result))
-        return (mean, std, top10)
+        return (mean, std, top10, result)
 
     def validate(self):
         """ Validate the model. """
         log('Validation model "{0}" epoch {1:<5}: begin\n'.format(self.tag, self.epoch))
-        (valid_mean, valid_std, valid_top10) = self.error('valid')
+        (valid_mean, valid_std, valid_top10, valid_distribution) = self.error('valid')
         log('Validation model "{0}" epoch {1:<5}: mean: {2:<15} valid std: {3:<15} valid top10: {4:<15}\n'.format(self.tag, self.epoch, valid_mean, valid_std, valid_top10))
-        datalog(self.config['datalog path']+'/'+self.config['model name'], self.epoch, valid_mean, valid_std, valid_top10)
+        datalog(self.config['datalog path']+'/'+self.config['model name'], 'summary', self.epoch, valid_mean, valid_std, valid_top10)
+        datalog(self.config['datalog path']+'/'+self.config['model name'], 'distribution', self.epoch, valid_distribution)
         if not hasattr(self, 'best_mean') or valid_mean < self.best_mean:
             self.best_mean = valid_mean
             if self.config['save best model']:
@@ -165,7 +166,7 @@ class Model(object):
                 log('Validation model "{0}" epoch {1:<5}: saved\n'.format(self.tag, self.epoch))
 
         if self.config['validate on training data']:
-            (train_mean, train_std, train_top10) = self.error('train')
+            (train_mean, train_std, train_top10, _) = self.error('train')
             log('Validation model "{0}" epoch {1:<5} train mean: {2:<15} std: {3:<15} train top10: {4:<15}\n'.format(self.tag, self.epochtrain_mean, train_std, train_top10))
 
     def test(self, save=True):
@@ -173,9 +174,10 @@ class Model(object):
         if save:
             log('# Test model "{0}": begin\n'.format(self.tag))
 
-        (mean, std, top10) = self.error('test')
+        (mean, std, top10, distribution) = self.error('test')
         log('# Test model "{0}": mean: {1:<15} std: {2:<15} top10: {3:<15}\n'.format(self.tag, mean, std, top10))
-        datalog(self.config['datalog path']+'/'+self.config['model name'], 'test', mean, std, top10)
+        datalog(self.config['datalog path']+'/'+self.config['model name'], 'summary', 'test', mean, std, top10)
+        datalog(self.config['datalog path']+'/'+self.config['model name'], 'distribution', 'test', distribution)
 
         if save:
             log('# Test model "{0}": saving...\n'.format(self.tag))