commit c2c88c48a0404de0eb834df71fa53ae63fdfd1c7
parent 6c45eb6e48775dcbbbd3177f02c1d1b0c161ba1e
Author: Alex Auvolat <alex.auvolat@ens.fr>
Date: Fri, 22 May 2015 15:50:07 -0400
Delete useless file apply_model.py
Diffstat:
1 file changed, 0 insertions(+), 43 deletions(-)
diff --git a/apply_model.py b/apply_model.py
@@ -1,43 +0,0 @@
-import theano
-
-from blocks.graph import ComputationGraph
-
-class Apply(object):
- def __init__(self, outputs, return_vars, stream):
- if not isinstance(outputs, list):
- outputs = [outputs]
- if not isinstance(return_vars, list):
- return_vars = [return_vars]
-
- self.outputs = outputs
- self.return_vars = return_vars
- self.stream = stream
-
- cg = ComputationGraph(self.outputs)
- self.input_names = [i.name for i in cg.inputs]
- self.f = theano.function(inputs=cg.inputs, outputs=self.outputs)
-
- def __iter__(self):
- self.iterator = self.stream.get_epoch_iterator(as_dict=True)
- while True:
- try:
- batch = next(self.iterator)
- except StopIteration:
- return
-
- inputs = [batch[n] for n in self.input_names]
- outputs = self.f(*inputs)
-
- def find_retvar(name):
- for idx, ov in enumerate(self.outputs):
- if ov.name == name:
- return outputs[idx]
-
- if name in batch:
- return batch[name]
-
- raise ValueError('Variable ' + name + ' neither in outputs or in batch variables.')
-
- yield {name: find_retvar(name) for name in self.return_vars}
-
-