commit 4ccc1084a6b4cff5eed0ab7b004a3eb012153e2a
parent b1e638bb2a94157b60866c12a0538457b83d6c59
Author: Étienne Simon <esimon@esimon.eu>
Date:   Fri, 18 Apr 2014 14:07:21 +0200
Fix train function
Diffstat:
| M | model.py |  |  | 32 | +++++++++++++++++++------------- | 
1 file changed, 19 insertions(+), 13 deletions(-)
diff --git a/model.py b/model.py
@@ -79,23 +79,30 @@ class Model(object):
 
         self.parameters = self.relations.parameters + self.embeddings.parameters
         inputs = tuple(S.csr_matrix() for _ in xrange(5))
-        positive_left, positive_right = self.embeddings.embed(inputs[1]), self.embeddings.embed(inputs[2])
-        negative_left, negative_right = self.embeddings.embed(inputs[3]), self.embeddings.embed(inputs[4])
+        left_positive, right_positive = self.embeddings.embed(inputs[1]), self.embeddings.embed(inputs[2])
+        left_negative, right_negative = self.embeddings.embed(inputs[3]), self.embeddings.embed(inputs[4])
         relation = self.relations.lookup(inputs[0])
-        positive_score = self.hyperparameters['similarity'](self.relations.apply(positive_left, relation), positive_right)
-        negative_score = self.hyperparameters['similarity'](self.relations.apply(negative_left, relation), negative_right)
-        score = self.hyperparameters['margin'] + positive_score - negative_score
-        violating_margin = score>0
-        criterion = T.mean(violating_margin*score)
+
+        score_positive = self.hyperparameters['similarity'](self.relations.apply(left_positive, relation), right_positive)
+        score_left_negative = self.hyperparameters['similarity'](self.relations.apply(left_negative, relation), right_positive)
+        score_right_negative = self.hyperparameters['similarity'](self.relations.apply(left_positive, relation), right_negative)
+        score_left = self.hyperparameters['margin'] + score_positive - score_left_negative
+        score_right = self.hyperparameters['margin'] + score_positive - score_right_negative
+
+        violating_margin_left = score_left>0
+        violating_margin_right = score_right>0
+        criterion_left = T.mean(violating_margin_left*score_left)
+        criterion_right = T.mean(violating_margin_right*score_right)
+        criterion = criterion_left + criterion_right
 
         self.train_function = theano.function(inputs=list(inputs), outputs=[criterion], updates=self.updates(criterion))
         self.normalise_function = theano.function(inputs=[], outputs=[], updates=self.embeddings.normalise_updates())
 
         relation = T.addbroadcast(relation, 0)
-        broadcasted_left = T.addbroadcast(positive_left, 0)
-        broadcasted_right = T.addbroadcast(positive_right, 0)
-        left_score = self.hyperparameters['similarity'](self.relations.apply(broadcasted_left, relation), positive_right)
-        right_score = self.hyperparameters['similarity'](self.relations.apply(positive_left, relation), broadcasted_right)
+        left_broadcasted = T.addbroadcast(left_positive, 0)
+        right_broadcasted = T.addbroadcast(right_positive, 0)
+        left_score = self.hyperparameters['similarity'](self.relations.apply(left_broadcasted, relation), right_positive)
+        right_score = self.hyperparameters['similarity'](self.relations.apply(left_positive, relation), right_broadcasted)
 
         self.left_scoring_function = theano.function(inputs=list(inputs[0:3]), outputs=[left_score])
         self.right_scoring_function = theano.function(inputs=list(inputs[0:3]), outputs=[right_score])
@@ -123,8 +130,7 @@ class Model(object):
                 self.validate(epoch)
 
             for (relation, left_positive, right_positive, left_negative, right_negative) in self.dataset.training_minibatch(batch_size):
-                c1=self.train_function(relation, left_positive, right_positive, left_positive, right_negative)
-                c2=self.train_function(relation, left_positive, right_positive, left_negative, right_positive)
+                self.train_function(relation, left_positive, right_positive, left_negative, right_negative)
                 self.normalise_function()
 
     def error(self, name):