commit 1797f9e475272ed8f3851155698fe27d942aed0a
parent 64d0763dbc5bee1bdea4aad867f2ecb4bac5ace9
Author: Étienne Simon <esimon@esimon.eu>
Date:   Wed, 18 Jun 2014 14:10:54 +0200
Add relations analyzer tool
Diffstat:
2 files changed, 54 insertions(+), 3 deletions(-)
diff --git a/dataset.py b/dataset.py
@@ -6,13 +6,13 @@ import numpy
 import theano
 
 class Dataset(object):
-    def __init__(self, prefix, config):
+    def __init__(self, prefix, config={}):
         self.config = config
         log('# Loading dataset "{0}"\n'.format(prefix))
         with open(prefix+'/entities', 'r') as file:
-            self.entities = file.readlines()
+            self.entities = file.read().splitlines()
         with open(prefix+'/relations', 'r') as file:
-            self.relations = file.readlines()
+            self.relations = file.read().splitlines()
         self.number_entities = len(self.entities)
         self.number_relations = len(self.relations)
         self.load_file(prefix, 'train')
diff --git a/utils/relations analyzer.py b/utils/relations analyzer.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python2
+
+from __future__ import print_function
+import sys
+
+from dataset import *
+
+if __name__ == '__main__':
+    if len(sys.argv)<2:
+        print('Usage: {0} dataset'.format(sys.argv[0]), file=sys.stderr)
+        sys.exit(1)
+    data = Dataset(sys.argv[1])
+
+    print('# Splitting dataset')
+    lefts, rights = [[] for _ in xrange(data.number_relations)], [[] for _ in xrange(data.number_relations)]
+    for relation, left, right in data.iterate("test"):
+        left, relation, right = left.indices[0], relation.indices[0], right.indices[0]
+        lefts[relation].append(left)
+        rights[relation].append(right)
+
+    total_injective, total_functional, total_onetoone, total_irreflexive = 0, 0, 0, 0
+    print(' '*123+'injective functional 1-to-1 irreflexive')
+    for index, relation in enumerate(data.relations):
+        left = lefts[index]
+        right = rights[index]
+        preimage = set(left)
+        image = set(right)
+        injective = len(image) == len(right)
+        functional = len(preimage) == len(left)
+        onetoone = injective & functional
+
+        irreflexive = True
+        for l,r in zip(left, right):
+            if l==r:
+                irreflexive = False
+
+        intersection = image & preimage
+
+        print('{0:<120} : {1:<9} {2:<10} {3:<6} {4:<11}'.format(relation, injective, functional, onetoone, irreflexive))
+
+        total_injective += injective
+        total_functional += functional
+        total_irreflexive += irreflexive
+        total_onetoone += onetoone
+    
+    N = float(len(data.relations))
+    print('')
+    print('Total injective: {0} ({1})'.format(total_injective/N, total_injective))
+    print('Total functional: {0} ({1})'.format(total_functional/N, total_functional))
+    print('Total 1-to-1: {0} ({1})'.format(total_onetoone/N, total_onetoone))
+    print('Total irreflexive: {0} ({1})'.format(total_irreflexive/N, total_irreflexive))