commit 996d461a6fb8dfe4da9c8298474ec5e87a144c7e
parent 748e9a2e558822a7ad9849dcef4b01df5eb77259
Author: Étienne Simon <esimon@esimon.eu>
Date:   Fri, 18 Apr 2014 15:48:32 +0200
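Rename hyperparameters to config (and the config path to config_path); read the model name from the config instead of hard-coding 'TransE'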
Diffstat:
| M | main.py |  |  | 18 | +++++++++--------- | 
| M | model.py |  |  | 50 | +++++++++++++++++++++++++------------------------- | 
2 files changed, 34 insertions(+), 34 deletions(-)
diff --git a/main.py b/main.py
@@ -11,23 +11,23 @@ from relations import *
 
 if __name__ == '__main__':
     if len(sys.argv)<3:
-        print('Usage: {0} data parameters [model]\n'.format(sys.argv[0]), file=sys.stderr)
+        print('Usage: {0} data parameters [model]'.format(sys.argv[0]), file=sys.stderr)
         sys.exit(1)
     data = sys.argv[1]
-    config = sys.argv[2]
+    config_path = sys.argv[2]
     model_path = None if len(sys.argv)<4 else sys.argv[3]
 
-    with open(config, 'r') as config_file:
-        hyperparameters = json.load(config_file)
-        for k, v in hyperparameters.iteritems():
+    with open(config_path, 'r') as config_file:
+        config = json.load(config_file)
+        for k, v in config.iteritems():
             if isinstance(v, basestring) and v.startswith('python:'):
-                hyperparameters[k] = eval(v[7:])
-    datalog_filepath = hyperparameters['datalog filepath']
+                config[k] = eval(v[7:])
+    datalog_filepath = config['datalog filepath']
 
     data = Dataset(data)
     if model_path is None:
-        model = Model.initialise(hyperparameters['relations'], data, hyperparameters, 'TransE')
+        model = Model.initialise(config['relations'], data, config, config['model name'])
     else:
-        model = Model.load(model_path, data, hyperparameters, 'TransE')
+        model = Model.load(model_path, data, config, config['model name'])
     model.train()
     model.test()
diff --git a/model.py b/model.py
@@ -23,35 +23,35 @@ class Model(object):
     """
 
     @classmethod
-    def initialise(cls, Relations, dataset, hyperparameters, tag):
+    def initialise(cls, Relations, dataset, config, tag):
         """ Initialise a model.
 
         Keyword arguments:
         Relations -- relations class
         dataset -- dataset on which the model will be trained and tested
-        hyperparameters -- hyperparameters dictionary
+        config -- config dictionary
         tag -- name of the embeddings for parameter declaration
         """
         log('# Initialising model "{0}"\n'.format(tag))
 
         self = cls()
-        self.embeddings = Embeddings(hyperparameters['rng'], dataset.number_embeddings, hyperparameters['dimension'], tag+'.embeddings')
-        self.relations = Relations(hyperparameters['rng'], dataset.number_relations, hyperparameters['dimension'], tag+'.relations')
+        self.embeddings = Embeddings(config['rng'], dataset.number_embeddings, config['dimension'], tag+'.embeddings')
+        self.relations = Relations(config['rng'], dataset.number_relations, config['dimension'], tag+'.relations')
         self.dataset = dataset
-        self.hyperparameters = hyperparameters
+        self.config = config
         self.tag = tag
 
         self.build()
         return self
 
     @classmethod
-    def load(cls, filepath, dataset, hyperparameters, tag):
+    def load(cls, filepath, dataset, config, tag):
         """ Load a model from a file.
 
         Keyword arguments:
         filepath -- path to the Model file
         dataset -- dataset on which the model will be trained and tested
-        hyperparameters -- hyperparameters dictionary
+        config -- config dictionary
         tag -- name of the embeddings for parameter declaration
         """
         log('# Loading model from "{0}"\n'.format(filepath))
@@ -61,7 +61,7 @@ class Model(object):
             self.embeddings = cPickle.load(file)
             self.relations = cPickle.load(file)
         self.dataset = dataset;
-        self.hyperparameters = hyperparameters;
+        self.config = config;
         self.tag = tag
 
         self.build()
@@ -83,11 +83,11 @@ class Model(object):
         left_negative, right_negative = self.embeddings.embed(inputs[3]), self.embeddings.embed(inputs[4])
         relation = self.relations.lookup(inputs[0])
 
-        score_positive = self.hyperparameters['similarity'](self.relations.apply(left_positive, relation), right_positive)
-        score_left_negative = self.hyperparameters['similarity'](self.relations.apply(left_negative, relation), right_positive)
-        score_right_negative = self.hyperparameters['similarity'](self.relations.apply(left_positive, relation), right_negative)
-        score_left = self.hyperparameters['margin'] + score_positive - score_left_negative
-        score_right = self.hyperparameters['margin'] + score_positive - score_right_negative
+        score_positive = self.config['similarity'](self.relations.apply(left_positive, relation), right_positive)
+        score_left_negative = self.config['similarity'](self.relations.apply(left_negative, relation), right_positive)
+        score_right_negative = self.config['similarity'](self.relations.apply(left_positive, relation), right_negative)
+        score_left = self.config['margin'] + score_positive - score_left_negative
+        score_right = self.config['margin'] + score_positive - score_right_negative
 
         violating_margin_left = score_left>0
         violating_margin_right = score_right>0
@@ -101,8 +101,8 @@ class Model(object):
         relation = T.addbroadcast(relation, 0)
         left_broadcasted = T.addbroadcast(left_positive, 0)
         right_broadcasted = T.addbroadcast(right_positive, 0)
-        left_score = self.hyperparameters['similarity'](self.relations.apply(left_broadcasted, relation), right_positive)
-        right_score = self.hyperparameters['similarity'](self.relations.apply(left_positive, relation), right_broadcasted)
+        left_score = self.config['similarity'](self.relations.apply(left_broadcasted, relation), right_positive)
+        right_score = self.config['similarity'](self.relations.apply(left_positive, relation), right_broadcasted)
 
         self.left_scoring_function = theano.function(inputs=list(inputs[0:3]), outputs=[left_score])
         self.right_scoring_function = theano.function(inputs=list(inputs[0:3]), outputs=[right_score])
@@ -113,17 +113,17 @@ class Model(object):
         Keyword arguments:
         cost -- The cost to optimise.
         """
-        lr_relations = self.hyperparameters['relation learning rate']
-        lr_embeddings = self.hyperparameters['embeddings learning rate']
+        lr_relations = self.config['relation learning rate']
+        lr_embeddings = self.config['embeddings learning rate']
         return self.relations.updates(cost, lr_relations) + self.embeddings.updates(cost, lr_embeddings)
 
     def train(self):
         """ Train the model. """
         log('# Training the model "{0}"\n'.format(self.tag))
 
-        batch_size = self.hyperparameters['train batch size']
-        validation_frequency = self.hyperparameters['validation frequency']
-        number_epoch = self.hyperparameters['number of epoch']
+        batch_size = self.config['train batch size']
+        validation_frequency = self.config['validation frequency']
+        number_epoch = self.config['number of epoch']
 
         for epoch in xrange(number_epoch):
             if epoch % validation_frequency == 0:
@@ -135,7 +135,7 @@ class Model(object):
 
     def error(self, name):
         """ Compute the mean rank and top 10 on a given data. """
-        batch_size = self.hyperparameters['test batch size']
+        batch_size = self.config['test batch size']
         count, mean, top10 = 0, 0, 0
         for (relation, left, right) in self.dataset.iterate(name):
             left_scores, right_scores = None, None
@@ -168,13 +168,13 @@ class Model(object):
         if not hasattr(self, 'best_mean') or valid_mean < self.best_mean:
             self.best_mean = valid_mean
             log('(best so far')
-            if self.hyperparameters['save best model']:
+            if self.config['save best model']:
                 log(', saving...')
-                self.save(self.hyperparameters['best model save location'])
+                self.save(self.config['best model save location'])
                 log(' done')
             log(')')
 
-        if self.hyperparameters['validate on training data']:
+        if self.config['validate on training data']:
             (train_mean, train_top10) = self.error('train')
             log(' train mean: {0:<15} train top10: {1:<15}'.format(train_mean, train_top10))
         log('\n')
@@ -184,5 +184,5 @@ class Model(object):
         log('# Testing the model "{0}"'.format(self.tag))
         (mean, top10) = self.error('test')
         log(' mean: {0:<15} top10: {1:<15} (saving...'.format(mean, top10))
-        self.save(self.hyperparameters['last model save location'])
+        self.save(self.config['last model save location'])
         log(' done)\n')