commit eee5f1e3d07f49043ce001b70af86095158a5bc6
parent 744de49cdcd00d5bead21197cc31bf226cdb03c0
Author: Étienne Simon <esimon@esimon.eu>
Date:   Wed, 16 Apr 2014 17:30:41 +0200
Genericise and fix some bugs
Diffstat:
4 files changed, 27 insertions(+), 21 deletions(-)
diff --git a/dataset.py b/dataset.py
@@ -73,3 +73,4 @@ class Dataset(object):
         entities = scipy.sparse.eye(N, format='csr', dtype=theano.config.floatX)
         for i in xrange(N/batch_size):
             yield entities[i*batch_size:(i+1)*batch_size]
+        yield entities[(N/batch_size)*batch_size:]
diff --git a/main.py b/main.py
@@ -1,23 +1,26 @@
 #!/usr/bin/env python2
 
+import sys
+import json
+
 from dataset import *
 from model import *
 from relations.translations import *
 
 if __name__ == '__main__':
-    hyperparameters = dict()
-    hyperparameters['similarity'] = L1_norm
-    hyperparameters['rng'] = numpy.random
-    hyperparameters['dimension'] = 20
-    hyperparameters['margin'] = 1.
-    hyperparameters['relation_learning_rate'] = 1
-    hyperparameters['embeddings_learning_rate'] = 0.1
-    hyperparameters['train_batch_size'] = 100
-    hyperparameters['test_batch_size'] = 500
-    hyperparameters['validation_frequency'] = 500
-    hyperparameters['number_epoch'] = 10000
+    if len(sys.argv)<3:
+        print >>sys.stderr, 'Usage: {0} data parameters'.format(sys.argv[0])
+        sys.exit(1)
+    data = sys.argv[1]
+    config = sys.argv[2]
+
+    with open(config, 'r') as config_file:
+        hyperparameters = json.load(config_file)
+        for k, v in hyperparameters.iteritems():
+            if isinstance(v, basestring):
+                hyperparameters[k] = eval(v)
 
-    data = Dataset('data/dummy')
-    model = Model.initialise(Translations, data, hyperparameters, 'dummy')
+    data = Dataset(data)
+    model = Model.initialise(Translations, data, hyperparameters, 'TransE')
     model.train()
     model.test()
diff --git a/model.py b/model.py
@@ -1,7 +1,5 @@
 #!/usr/bin/env python2
 
-import time
-
 import sys
 import numpy
 import scipy
@@ -122,12 +120,15 @@ class Model(object):
         for (relation, left, right) in self.dataset.iterate(name, batch_size):
             scores = None
             for entities in self.dataset.universe_minibatch(batch_size):
+                if left.shape != entities.shape:
+                    left = left[0:entities.shape[0]]
+                    relation = relation[0:entities.shape[0]]
                 batch_result = self.scoring_function(relation, left, entities)
                 scores = numpy.array(batch_result, dtype=theano.config.floatX) if scores is None else numpy.concatenate((scores, batch_result), axis=1)
-            rank = 1+numpy.where(numpy.argsort(scores)==right.indices[0])[1] # FIXME ugly
-            mean += rank
-            count += 1
-            top10 += (rank<=10)
+            rank = numpy.asscalar(numpy.where(numpy.argsort(scores)==right.indices[0])[1]) # FIXME Ugly
+            mean = mean + rank
+            count = count + 1
+            top10 = top10 + (rank<=10)
         mean = float(mean) / count
         top10 = float(top10) / count
         return (mean, top10)
@@ -136,8 +137,9 @@ class Model(object):
         """ Validate the model. """
         print >>sys.stderr, 'Validation epoch {:<5}'.format(epoch),
         (valid_mean, valid_top10) = self.error('valid')
+        print >>sys.stderr, 'valid mean: {0:<15} valid top10: {1:<15}'.format(valid_mean, valid_top10)
         (train_mean, train_top10) = self.error('train')
-        print >>sys.stderr, 'valid mean: {0:<15} valid top10: {1:<15} train mean: {0:<15} train top10: {1:<15}'.format(valid_mean, valid_top10, train_mean, train_top10)
+        print >>sys.stderr, 'train mean: {0:<15} train top10: {1:<15}'.format(train_mean, train_top10)
 
     def test(self):
         """ Test the model. """
diff --git a/utils/build Bordes FB15k.py b/utils/build Bordes FB15k.py
@@ -57,7 +57,7 @@ def compile_dataset(path):
 
     print >>sys.stderr, 'Writting entities...',
     e2i, i2e, r2i, i2r = {}, {}, {}, {}
-    with open(path+'/entities', 'w') as file:
+    with open(path+'/embeddings', 'w') as file:
         i=0
         for entity in entities:
             e2i[entity]=i