eval.py (909B)
from typing import Any, Dict
import pathlib

import torch

import gbure.train
import gbure.utils


class Evaluator(gbure.train.Trainer):
    """
    Evaluate a model.

    Config:
        valid: only evaluate on the validation split
        test: only evaluate on the test split

    When neither (or both) of ``valid`` and ``test`` is set, both splits are
    evaluated, validation first.
    """

    def main(self) -> None:
        """ Run the experiment (i.e. here, evaluate). """
        only_valid: bool = bool(self.config.get("valid"))
        only_test: bool = bool(self.config.get("test"))
        # A split is evaluated if it was requested, or if the *other* split was
        # not requested exclusively; with no flag at all both splits run.
        if only_valid or not only_test:
            self.evaluate("valid")
        if only_test or not only_valid:
            self.evaluate("test")


def main() -> None:
    """Parse command-line arguments, load the checkpoint at ``config.load`` and evaluate it."""
    gbure.utils.fix_transformers_logging_handler()
    config: gbure.utils.dotdict = gbure.utils.parse_args()

    # NOTE(review): torch.load unpickles the checkpoint, which here contains
    # non-tensor objects (a pathlib.Path under "logdir"). Only load trusted
    # files; recent torch releases default to weights_only=True, which may
    # reject such objects — confirm against the installed torch version.
    state_dicts: Dict[str, Any] = torch.load(config.load)
    logdir: pathlib.Path = state_dicts["logdir"]
    gbure.utils.add_logging_handler(logdir)
    Evaluator(config, logdir, state_dicts).run()


if __name__ == "__main__":
    main()