commit 4ccc1084a6b4cff5eed0ab7b004a3eb012153e2a
parent b1e638bb2a94157b60866c12a0538457b83d6c59
Author: Étienne Simon <esimon@esimon.eu>
Date: Fri, 18 Apr 2014 14:07:21 +0200
Fix train function
Diffstat:
M | model.py | | | 32 | +++++++++++++++++++------------- |
1 file changed, 19 insertions(+), 13 deletions(-)
diff --git a/model.py b/model.py
@@ -79,23 +79,30 @@ class Model(object):
self.parameters = self.relations.parameters + self.embeddings.parameters
inputs = tuple(S.csr_matrix() for _ in xrange(5))
- positive_left, positive_right = self.embeddings.embed(inputs[1]), self.embeddings.embed(inputs[2])
- negative_left, negative_right = self.embeddings.embed(inputs[3]), self.embeddings.embed(inputs[4])
+ left_positive, right_positive = self.embeddings.embed(inputs[1]), self.embeddings.embed(inputs[2])
+ left_negative, right_negative = self.embeddings.embed(inputs[3]), self.embeddings.embed(inputs[4])
relation = self.relations.lookup(inputs[0])
- positive_score = self.hyperparameters['similarity'](self.relations.apply(positive_left, relation), positive_right)
- negative_score = self.hyperparameters['similarity'](self.relations.apply(negative_left, relation), negative_right)
- score = self.hyperparameters['margin'] + positive_score - negative_score
- violating_margin = score>0
- criterion = T.mean(violating_margin*score)
+
+ score_positive = self.hyperparameters['similarity'](self.relations.apply(left_positive, relation), right_positive)
+ score_left_negative = self.hyperparameters['similarity'](self.relations.apply(left_negative, relation), right_positive)
+ score_right_negative = self.hyperparameters['similarity'](self.relations.apply(left_positive, relation), right_negative)
+ score_left = self.hyperparameters['margin'] + score_positive - score_left_negative
+ score_right = self.hyperparameters['margin'] + score_positive - score_right_negative
+
+ violating_margin_left = score_left>0
+ violating_margin_right = score_right>0
+ criterion_left = T.mean(violating_margin_left*score_left)
+ criterion_right = T.mean(violating_margin_right*score_right)
+ criterion = criterion_left + criterion_right
self.train_function = theano.function(inputs=list(inputs), outputs=[criterion], updates=self.updates(criterion))
self.normalise_function = theano.function(inputs=[], outputs=[], updates=self.embeddings.normalise_updates())
relation = T.addbroadcast(relation, 0)
- broadcasted_left = T.addbroadcast(positive_left, 0)
- broadcasted_right = T.addbroadcast(positive_right, 0)
- left_score = self.hyperparameters['similarity'](self.relations.apply(broadcasted_left, relation), positive_right)
- right_score = self.hyperparameters['similarity'](self.relations.apply(positive_left, relation), broadcasted_right)
+ left_broadcasted = T.addbroadcast(left_positive, 0)
+ right_broadcasted = T.addbroadcast(right_positive, 0)
+ left_score = self.hyperparameters['similarity'](self.relations.apply(left_broadcasted, relation), right_positive)
+ right_score = self.hyperparameters['similarity'](self.relations.apply(left_positive, relation), right_broadcasted)
self.left_scoring_function = theano.function(inputs=list(inputs[0:3]), outputs=[left_score])
self.right_scoring_function = theano.function(inputs=list(inputs[0:3]), outputs=[right_score])
@@ -123,8 +130,7 @@ class Model(object):
self.validate(epoch)
for (relation, left_positive, right_positive, left_negative, right_negative) in self.dataset.training_minibatch(batch_size):
- c1=self.train_function(relation, left_positive, right_positive, left_positive, right_negative)
- c2=self.train_function(relation, left_positive, right_positive, left_negative, right_positive)
+ self.train_function(relation, left_positive, right_positive, left_negative, right_negative)
self.normalise_function()
def error(self, name):