commit 5885567c2353298286e3594296e66c10968ef864
parent 6881a2a302c5abc3e2ef4b710fa2033ce83615ea
Author: Étienne Simon <esimon@esimon.eu>
Date:   Wed, 30 Apr 2014 18:33:24 +0200
Add multithreaded training
Diffstat:
5 files changed, 53 insertions(+), 28 deletions(-)
diff --git a/dataset.py b/dataset.py
@@ -6,8 +6,7 @@ import numpy
 import theano
 
 class Dataset(object):
-    def __init__(self, prefix, rng):
-        self.rng = rng
+    def __init__(self, prefix):
         log('# Loading dataset "{0}"\n'.format(prefix))
         with open(prefix+'/embeddings', 'r') as file:
             self.embeddings = file.readlines()
@@ -30,11 +29,11 @@ class Dataset(object):
         setattr(self, name+'_relation', scipy.sparse.csr_matrix(([1]*N, relation, range(N+1)), shape=(N, self.number_relations), dtype=theano.config.floatX))
         setattr(self, name+'_left', scipy.sparse.csr_matrix(([1]*N, left, range(N+1)), shape=(N, self.number_embeddings), dtype=theano.config.floatX))
 
-    def training_minibatch(self, batch_size):
+    def training_minibatch(self, rng, batch_size):
         # Sampling corrupted entities
         def sample_matrix():
             row = range(self.train_size+1)
-            col = self.rng.randint(0, self.number_embeddings, size=self.train_size)
+            col = rng.randint(0, self.number_embeddings, size=self.train_size)
             data = numpy.ones(self.train_size)
             random_embeddings = scipy.sparse.csr_matrix((data, col, row), shape=(self.train_size, self.number_embeddings), dtype=theano.config.floatX)
             return random_embeddings
@@ -42,7 +41,7 @@ class Dataset(object):
         corrupted_right = sample_matrix()
 
         # Shuffling training set
-        order = self.rng.permutation(self.train_size)
+        order = rng.permutation(self.train_size)
         train_left = self.train_left[order, :]
         train_right = self.train_right[order, :]
         train_relation = self.train_relation[order, :]
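
Note: the RNG moves from Dataset state to a per-call argument so that each
training thread can sample corrupted entities and shuffle orders from its own
generator; sharing one generator across threads would make every model's
random stream depend on thread scheduling. A minimal sketch of per-model
generators (the seeds and wiring here are hypothetical, not part of this
commit):

    import numpy

    # One independent generator per model, so the training threads never
    # touch the same RandomState concurrently.
    rngs = [ numpy.random.RandomState(seed) for seed in (1, 2, 3) ]
    # Each model's training loop would then call, e.g.:
    #     dataset.training_minibatch(rngs[i], batch_size)
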
diff --git a/meta_model.py b/meta_model.py
@@ -4,6 +4,7 @@ from utils.log import *
 from config import *
 from model import *
 import numpy
+import threading
 
 class Meta_model(object):
     """ Meta-model class. """
@@ -20,6 +21,10 @@ class Meta_model(object):
         for model in self.models:
             model.build_test()
 
+    def build_train(self):
+        for model in self.models:
+            model.build_train()
+
     def left_scoring_function(self, relation, left, right):
         res = [ model.left_scoring_function(relation, left, right) for model in self.models ]
         return numpy.transpose(res).reshape(right.shape[0], len(self.models))
@@ -28,10 +33,10 @@ class Meta_model(object):
         res = [ model.right_scoring_function(relation, left, right) for model in self.models ]
         return numpy.transpose(res).reshape(left.shape[0], len(self.models))
 
-    def error(self):
+    def error(self, name):
         """ Compute the mean rank, standard deviation and top 10 on a given data. """
         result = []
-        for (relation, left, right) in self.dataset.iterate('test'):
+        for (relation, left, right) in self.dataset.iterate(name):
             entities = self.dataset.universe
             raw_left_scores = self.left_scoring_function(relation, left, entities)
             raw_right_scores = self.right_scoring_function(relation, entities, right)
@@ -48,5 +53,16 @@ class Meta_model(object):
     def test(self):
         """ Test the model. """
         log('# Testing the model')
-        (mean, std, top10) = self.error()
+        (mean, std, top10) = self.error('test')
         log(' mean: {0:<15} std: {1:<15} top10: {2:<15}\n'.format(mean, std, top10))
+
+    def train(self):
+        """ Train the model. """
+        threads = [ threading.Thread(target=model.train, args=()) for model in self.models ]
+        log('# Starting thread for model {0}\n'.format(model.tag))
+        for thread in threads:
+            thread.start()
+        log('# Waiting for children to join\n')
+        for thread in threads:
+            thread.join()
+        log('# All children joined\n')
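
Note: Meta_model.train fans out one thread per sub-model and blocks until all
of them finish. The start/join pattern in isolation, as a self-contained
sketch (worker and the tags are stand-ins, not repo code):

    import threading

    def worker(tag):
        # Stand-in for Model.train: each thread runs one model's training loop.
        print('training model ' + tag)

    threads = [ threading.Thread(target=worker, args=(tag,)) for tag in ('A', 'B', 'C') ]
    for thread in threads:
        thread.start()   # launch all trainings concurrently
    for thread in threads:
        thread.join()    # block until every model has finished

Under CPython the GIL serializes Python bytecode, so the practical speed-up
depends on the numeric kernels (Theano's compiled ops, BLAS calls) releasing
the GIL while they run.
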
diff --git a/model.py b/model.py
@@ -56,6 +56,7 @@ class Model(object):
     def save(self, filepath):
         """ Save the model in a file. """
         with open(filepath, 'wb') as file:
+            cPickle.dump(self.epoch, file, -1)
             cPickle.dump(self.embeddings, file, -1)
             cPickle.dump(self.relations, file, -1)
 
@@ -125,7 +126,7 @@ class Model(object):
         number_epoch = self.config['number of epoch']
 
         while self.epoch < number_epoch:
-            for (relation, left_positive, right_positive, left_negative, right_negative) in self.dataset.training_minibatch(batch_size):
+            for (relation, left_positive, right_positive, left_negative, right_negative) in self.dataset.training_minibatch(self.config['rng'], batch_size):
                 self.normalise_function()
                 self.train_function(relation, left_positive, right_positive, left_negative, right_negative)
 
@@ -152,28 +153,26 @@ class Model(object):
 
     def validate(self):
         """ Validate the model. """
-        log('Validation epoch {:<5}'.format(self.epoch))
+        log('Validation model "{0}" epoch {1:<5}: begin\n'.format(self.tag, self.epoch))
         (valid_mean, valid_std, valid_top10) = self.error('valid')
-        log(' valid mean: {0:<15} valid std: {1:<15} valid top10: {2:<15}'.format(valid_mean, valid_std, valid_top10))
+        log('Validation model "{0}" epoch {1:<5}: valid mean: {2:<15} valid std: {3:<15} valid top10: {4:<15}\n'.format(self.tag, self.epoch, valid_mean, valid_std, valid_top10))
         datalog(self.config['datalog path']+'/'+self.config['model name'], self.epoch, valid_mean, valid_std, valid_top10)
         if not hasattr(self, 'best_mean') or valid_mean < self.best_mean:
             self.best_mean = valid_mean
-            log('(best so far')
             if self.config['save best model']:
-                log(', saving...')
+                log('Validation model "{0}" epoch {1:<5}: best model so far, saving...\n'.format(self.tag, self.epoch))
                 self.save('{0}/{1}.best'.format(self.config['best model save path'], self.config['model name']))
-                log(' done')
-            log(')')
+                log('Validation model "{0}" epoch {1:<5}: saved\n'.format(self.tag, self.epoch))
 
         if self.config['validate on training data']:
             (train_mean, train_std, train_top10) = self.error('train')
-            log(' train mean: {0:<15} std: {1:<15} train top10: {2:<15}'.format(train_mean, train_std, train_top10))
-        log('\n')
+            log('Validation model "{0}" epoch {1:<5}: train mean: {2:<15} train std: {3:<15} train top10: {4:<15}\n'.format(self.tag, self.epoch, train_mean, train_std, train_top10))
 
     def test(self):
         """ Test the model. """
-        log('# Testing the model "{0}"'.format(self.tag))
+        log('# Test model "{0}": begin\n'.format(self.tag))
         (mean, std, top10) = self.error('test')
-        log(' mean: {0:<15} std: {1:<15} top10: {2:<15} (saving...'.format(mean, std, top10))
+        log('# Test model "{0}": mean: {1:<15} std: {2:<15} top10: {3:<15}\n'.format(self.tag, mean, std, top10))
+        log('# Test model "{0}": saving...\n'.format(self.tag))
         self.save('{0}/{1}.last'.format(self.config['last model save path'], self.config['model name']))
-        log(' done)\n')
+        log('# Test model "{0}": saved\n'.format(self.tag))
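
Note: every log call now emits one complete, self-describing line (model tag,
epoch, trailing newline) instead of assembling a line across several calls;
with multiple training threads logging at once, fragments from different
models would otherwise interleave. The repo's log helper lives in utils.log
and is not shown here; a minimal lock-guarded sketch of such a helper:

    import sys
    import threading

    _log_lock = threading.Lock()  # hypothetical; the real utils.log is not in this diff

    def log(message):
        # One complete line per call keeps concurrent threads' output readable.
        with _log_lock:
            sys.stderr.write(message)
            sys.stderr.flush()
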
diff --git a/test.py b/test.py
@@ -20,14 +20,15 @@ if __name__ == '__main__':
     else: model_pathes = sys.argv[3]
 
     config = load_config(config_path)
-    if not config.get('meta', False) and model_pathes is None:
-        model_pathes = '{0}/{1}.best'.format(config['best model save path'], config['model name'])
+    if model_pathes is None:
+        if not config.get('meta', False):
+            model_pathes = '{0}/{1}.best'.format(config['best model save path'], config['model name'])
     if not config.get('meta', False) and isinstance(model_pathes, list):
         print('Error: multiple models specified while running in single mode', file=sys.stderr)
         sys.exit(1)
     ModelType = Meta_model if config.get('meta', False) else Model
 
-    data = Dataset(data, config['rng'])
+    data = Dataset(data)
     model = ModelType(data, config, model_pathes)
     model.build_test()
     model.test()
diff --git a/train.py b/train.py
@@ -5,20 +5,30 @@ import sys
 
 from dataset import *
 from model import *
-from relations import *
+from meta_model import *
 from config import *
 
 if __name__ == '__main__':
     if len(sys.argv)<3:
-        print('Usage: {0} data config [model]'.format(sys.argv[0]), file=sys.stderr)
+        print('Usage: {0} data config [models]'.format(sys.argv[0]), file=sys.stderr)
         sys.exit(1)
     data = sys.argv[1]
     config_path = sys.argv[2]
-    model_path = None if len(sys.argv)<4 else sys.argv[3]
+
+    if len(sys.argv)<4: model_pathes = None
+    elif len(sys.argv)>4: model_pathes = sys.argv[3:]
+    else: model_pathes = sys.argv[3]
 
     config = load_config(config_path)
-    data = Dataset(data, config['rng'])
-    model = Model(data, config, model_path)
+    if config.get('meta', False) and len(sys.argv)<4:
+        model_pathes = [ None ] * config['size']
+    if not config.get('meta', False) and isinstance(model_pathes, list):
+        print('Error: multiple models specified while running in single mode', file=sys.stderr)
+        sys.exit(1)
+
+    ModelType = Meta_model if config.get('meta', False) else Model
+    data = Dataset(data)
+    model = ModelType(data, config, model_pathes)
 
     model.build_train()
     model.build_test()
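
Note: train.py and model.py read several config keys that this diff references
but whose files are not shown. A hypothetical config illustrating them (all
values invented; how meta mode derives a per-model 'rng' for each sub-model is
outside this diff):

    import numpy

    config = {
        'meta': True,                           # route train.py/test.py to Meta_model
        'size': 4,                              # sub-model count when no model paths are given
        'rng': numpy.random.RandomState(123),   # consumed by Model.train for minibatch sampling
        'number of epoch': 500,
        'model name': 'example',
        'save best model': True,
        'best model save path': 'models',
        'last model save path': 'models',
        'datalog path': 'logs',
        'validate on training data': False,
    }
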