commit 9abf658fda3a8d55c9a1edfe44f7f0617cd086b4
parent 0f5f3fdc2a50c5dfecf7f71bd9b3cf60a9fb6eee
Author: Étienne Simon <esimon@esimon.eu>
Date:   Wed, 16 Apr 2014 18:11:21 +0200
Fix training set splitting for minibatch
Diffstat:
2 files changed, 8 insertions(+), 6 deletions(-)
diff --git a/dataset.py b/dataset.py
@@ -48,11 +48,13 @@ class Dataset(object):
         # Yielding batches
         ls = numpy.linspace(0, self.train_size, 1+self.train_size/batch_size)
         for i in xrange(len(ls)-1):
-            left_positive = train_left[ls[i]:ls[i+1]]
-            right_positive = train_right[ls[i]:ls[i+1]]
-            left_negative = corrupted_left[ls[i]:ls[i+1]]
-            right_negative = corrupted_right[ls[i]:ls[i+1]]
-            relation = train_relation[ls[i]:ls[i+1]]
+            f = int(ls[i])
+            t = int(ls[i+1])
+            left_positive = train_left[f:t]
+            right_positive = train_right[f:t]
+            left_negative = corrupted_left[f:t]
+            right_negative = corrupted_right[f:t]
+            relation = train_relation[f:t]
             yield (relation, left_positive, right_positive, left_negative, right_negative)
 
     def iterate(self, name, batch_size):
diff --git a/model.py b/model.py
@@ -117,7 +117,7 @@ class Model(object):
         """ Compute the mean rank and top 10 on a given data. """
         batch_size = self.hyperparameters['test_batch_size']
         count, mean, top10 = 0, 0, 0
-        for (relation, left, right) in self.dataset.iterate(name, batch_size):
+        for (relation, left, right) in self.dataset.iterate(name, batch_size): # TODO Test symmetric
             scores = None
             for entities in self.dataset.universe_minibatch(batch_size):
                 if left.shape != entities.shape: