gbure

Graph-based approaches to unsupervised relation extraction, evaluated as a few-shot problem
git clone https://esimon.eu/repos/gbure.git

utils.py (14138B)


from typing import Any, Callable, Dict, List, NoReturn, Optional, Union
import hashlib
import importlib
import logging
import multiprocessing
import os
import pathlib
import signal
import subprocess
import sys
import time
import types

import torch

logger = logging.getLogger(__name__)

# Tell pytype that this module's attributes are set dynamically
# (import_environment below injects names into globals()).
_HAS_DYNAMIC_ATTRIBUTES = True


def import_environment(name: str, cast: type = str) -> None:
    """ Import an environment variable into the global namespace. """
    try:
        globals()[name] = cast(os.environ[name])
    except KeyError:
        print(f"ERROR: {name} environment variable is not set.", file=sys.stderr)
        sys.exit(1)


import_environment("DATA_PATH", pathlib.Path)
import_environment("LOG_PATH", pathlib.Path)


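# Invocation sketch: the module refuses to load unless both variables are set,
# so a run looks like the following (paths and script name are hypothetical):
#
#   DATA_PATH=/tmp/data LOG_PATH=/tmp/logs python train.py ...

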
class dotdict(dict):
    """ Dictionary which can be accessed through var.key instead of var["key"]. """
    def __getattr__(self, name: str) -> Any:
        if name not in self:
            raise AttributeError(f"Config key {name} not found")
        return dotdict(self[name]) if type(self[name]) is dict else self[name]
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


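# Usage sketch for dotdict (keys and values are hypothetical):
#
#   >>> config = dotdict({"optimizer": {"lr": 1e-3}})
#   >>> config.optimizer.lr    # nested dicts are wrapped on access
#   0.001
#   >>> config.seed = 0        # equivalent to config["seed"] = 0
#   >>> config.missing         # raises AttributeError("Config key missing not found")

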
def eval_arg(config: Dict[str, Any], arg: str) -> None:
    """
    Evaluate arg in the context config, and update it.

    The argument is expected to be of the form:
        (parent.)*key(=value)?
    If no value is provided, the key is assumed to be a boolean and True is assigned to it.
    When passing a string argument through the shell, it must be enclosed in quotes (like any Python string), which usually need to be escaped.
    """
    key: str
    value: Any
    if '=' in arg:
        key, value = arg.split('=', maxsplit=1)
        value = eval(value, config)
        # eval() inserts __builtins__ into its globals; remove it from the
        # config here, before the name config is rebound to a nested dict.
        config.pop("__builtins__", None)
    else:
        key, value = arg, True
    path: List[str] = key.split('.')
    for component in path[:-1]:
        config = config[component]
    config[path[-1]] = value


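# Examples of arguments handled by eval_arg (keys and values are hypothetical):
#
#   >>> config = {"optimizer": {"lr": 1e-3}}
#   >>> eval_arg(config, "optimizer.lr=1e-4")    # nested assignment, value is eval'd
#   >>> eval_arg(config, "deterministic")        # bare key is set to True
#   >>> eval_arg(config, 'name="bert"')          # strings need (escaped) quotes

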
def import_arg(config: Dict[str, Any], arg: str) -> None:
    """
    Load file arg, and update config with its content.

    The file is loaded in an independent context; all the variables defined in the file (even through imports) are added to config, with the exception of builtins and whole modules.
    """
    if arg.endswith(".py"):
        arg = arg[:-3].replace('/', '.')
    module: types.ModuleType = importlib.import_module(arg)
    for key, value in vars(module).items():
        if key not in module.__builtins__ and not key.startswith("__") and not isinstance(value, types.ModuleType):  # pytype: disable=attribute-error
            config[key] = value


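# Sketch: given a file gbure/config/example.py (hypothetical) containing
#
#   batch_size = 32
#   learning_rate = 1e-5
#
# both import_arg(config, "gbure/config/example.py") and
# import_arg(config, "gbure.config.example") add batch_size and learning_rate
# to config, while skipping builtins, dunders and imported modules.

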
def parse_args() -> dotdict:
    """
    Parse command line arguments and return config dictionary.

    Two kinds of arguments are supported:
        - When the argument starts with -- it is evaluated by the eval_arg function
        - Otherwise the argument is assumed to be a file which is loaded by the import_arg function
    """
    config: Dict[str, Any] = {}
    # Temporary self-reference, so that --arguments can mention config explicitly.
    config["config"] = config
    for arg in sys.argv[1:]:
        if arg.startswith("--"):
            eval_arg(config, arg[2:])
        else:
            import_arg(config, arg)
    config.pop("config")
    return dotdict(config)


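# A full command line mixing both kinds of arguments (file name hypothetical):
#
#   python train.py gbure/config/example.py --"batch_size=64" --deterministic
#
# loads the file first, then overrides batch_size and sets deterministic=True.
# Since eval_arg evaluates values with the config as globals, an argument can
# reference earlier keys, e.g. --"eval_batch_size=batch_size*2".

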
def display_dict(output: Callable[[str], None], input: Dict[str, Any], depth: int = 0) -> None:
    """ Display nested dictionaries in input using the provided output function. """
    for key, value in input.items():
        indent = '\t'*depth
        output(f"{indent}{key}:")
        if isinstance(value, dict):
            output('\n')
            display_dict(output, value, depth+1)
        else:
            output(f" {value}\n")


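# For example (hypothetical input), display_dict(out, {"a": 1, "b": {"c": 2}})
# produces, with one tab per nesting level:
#
#   a: 1
#   b:
#           c: 2

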
def print_dict(input: Dict[str, Any]) -> None:
    """ Print dictionary to standard output. """
    display_dict(lambda x: print(x, end=""), input)


def log_dict(logger: logging.Logger, input: Dict[str, Any]) -> None:
    """ Log dictionary to the provided logger. """
    class log:
        # Accumulate output in a buffer, and emit one log record per line.
        buf: str = ""

        def __call__(self, x: str) -> None:
            self.buf += x
            if self.buf.endswith('\n'):
                logger.info(self.buf[:-1])
                self.buf = ""
    display_dict(log(), input)


def flatten_dict(input: Dict[str, Any]) -> Dict[str, Union[bool, int, float, str]]:
    """
    Replace nested dict by dot-separated keys, and cast values to simple types.

    repr() is used to cast non-base-type to str.
    """
    def impl(result: Dict[str, Union[bool, int, float, str]], input: Dict[str, Any], prefix: str):
        for key, value in input.items():
            if isinstance(value, dict):
                # Accumulate the prefix so that deeply nested keys keep their full path.
                impl(result, value, f"{prefix}{key}.")
            else:
                result[f"{prefix}{key}"] = value if type(value) in [bool, int, float, str] else repr(value)

    result: Dict[str, Union[bool, int, float, str]] = {}
    impl(result, input, "")
    return result


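# Example (hypothetical input):
#
#   >>> flatten_dict({"optimizer": {"lr": 1e-3}, "layers": [1, 2]})
#   {'optimizer.lr': 0.001, 'layers': '[1, 2]'}
#
# The list is not a base type, so it is stored through repr().

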
def get_repo_version() -> str:
    """ Get the code repository version. """
    repo_dir = pathlib.Path(__file__).parents[0]
    result: subprocess.CompletedProcess = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            encoding="utf-8",
            cwd=repo_dir)

    if result.returncode != 0:
        return "release"
    commit_hash: str = result.stdout.strip()[:8]

    result = subprocess.run(
            ["git", "status", "--porcelain"],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            encoding="utf-8",
            cwd=repo_dir)
    modified_flag: str = ""
    for line in result.stdout.split('\n'):
        # In porcelain format, " M path" marks a file with unstaged modifications.
        if line.startswith(" M "):
            modified_flag = "+"
            break

    return f"{commit_hash}{modified_flag}"


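# Possible return values: "release" when the code does not live in a git
# checkout, "0a1b2c3d" (first 8 hex digits of HEAD, hypothetical here) for a
# clean checkout, and "0a1b2c3d+" when unstaged modifications are detected.

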
def experiment_name(name: str) -> str:
    """ Name of the experiment (contains repository version, arguments and time). """
    args: str = ' '.join(sys.argv[1:])
    version: str = get_repo_version()
    stime: str = time.strftime("%FT%H:%M:%S")
    return f"{name} {version} {args} {stime}"


def logdir_name(name: str) -> str:
    """ Name of the experiment directory: the experiment_name, clipped to satisfy filesystem constraints. """
    subdir: str = experiment_name(name).replace('/', '_')
    if len(subdir) > 255:
        sha1: str = hashlib.sha1(subdir.encode("utf-8")).digest().hex()[:16]
        subdir = subdir[:255-17] + ' ' + sha1
    return subdir


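# Clipping sketch: a 300-character name (hypothetical) is cut to 238
# characters, then a space and the first 16 hex digits of its SHA-1 are
# appended, keeping the result at the usual 255-character filename limit
# while remaining unique.

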
def fix_transformers_logging_handler() -> None:
    """ The transformers package from huggingface installs its own logging handler on import; I don't want it. """
    logger: logging.Logger = logging.getLogger()
    for handler in logger.handlers:
        logger.removeHandler(handler)


def add_logging_handler(logdir: pathlib.Path) -> None:
    """ Log to a file named "log" inside the given experiment directory. """
    logfile: pathlib.Path = logdir / "log"
    logging.basicConfig(format="%(asctime)s\t%(levelname)s:%(name)s:%(message)s", filename=logfile, filemode='a', level=logging.INFO)


def save_patch(outpath: pathlib.Path) -> None:
    """ Save a file at the given path containing the diff between the current code and the last commit. """
    repo_dir = pathlib.Path(__file__).parents[0]

    with outpath.open("w") as outfile:
        result: subprocess.CompletedProcess = subprocess.run(
                ["git", "diff", "HEAD"],
                stdout=outfile,
                stderr=subprocess.DEVNULL,
                encoding="utf-8",
                cwd=repo_dir)

    assert(result.returncode == 0)


class Experiment:
    """
    Base class for running an experiment.

    Calling run on an instance of this class will call the init() then main() functions.
    The sole purpose of this class is to make an experiment "prettier": it displays config values, stores a diff of the repo in the experiment logdir, etc.

    Config:
        deterministic: run in deterministic mode
        seed: seed for random number generators
    """

    _HAS_DYNAMIC_ATTRIBUTES = True

    def __init__(self, config: dotdict, logdir: pathlib.Path, state_dicts: Optional[Dict[str, Any]] = None) -> None:
        self.config = config
        self.logdir = logdir
        self.state_dicts = state_dicts

    def init(self) -> None:
        """ Prepare the experiment (e.g. initialize datasets and models). """
        pass

    def main(self) -> NoReturn:
        """ Run the experiment itself. """
        raise NotImplementedError("Subclasses must implement a main method")

    def close(self) -> None:
        """ Free used resources (e.g. close opened files). """
        pass

    def run(self) -> None:
        """ Run the whole experiment, including setup and teardown. """
        self.log_environment()
        self.log_patch()
        self.initialize_rng()
        self.init()
        self.hook_signals()
        self.main()
        self.close()

    def log_environment(self) -> None:
        """ Display information about the environment. """
        print(f"logdir is \033[1m\033[33m{self.logdir}\033[0m")
        self.version_check()
        self.detect_gpus()
        print("")

        print("\033[1m\033[33mConfiguration\033[0m")
        print_dict(self.config)
        log_dict(logging.getLogger("config"), self.config)
        print("")

    def version_check(self) -> None:
        """ Check the version of the main dependencies. """
        python_version: str = '.'.join(map(str, sys.version_info[:3]))
        torch_version: str = torch.__version__
        cuda_available: str = str(torch.cuda.is_available())

        logger.info(f"python version {python_version}")
        logger.info(f"torch version {torch_version}")
        logger.info(f"cuda available {cuda_available}")

        def problem(msg: str) -> str:
            return f"\033[1m\033[31m{msg}\033[0m"
        if sys.version_info < (3, 7):
            python_version = problem(python_version)
        if list(map(int, torch_version.split('+')[0].split('.'))) < [1, 6]:
            torch_version = problem(torch_version)
        if cuda_available == "False":
            cuda_available = problem(cuda_available)
        print(f"python version: {python_version}, torch version: {torch_version}, cuda available: {cuda_available}")

    def detect_gpus(self) -> None:
        """ Display available gpus and set self.device. """
        count: int = torch.cuda.device_count()

        if count == 0:
            print("\033[1m\033[31mNo GPU available\033[0m")
            logger.warning("no GPU available")
            self.device = torch.device("cpu")
        else:
            self.device = torch.device("cuda:0")

        for i in range(count):
            gp = torch.cuda.get_device_properties(i)
            print(f"GPU{i}: \033[33m{gp.name}\033[0m (Mem: {gp.total_memory/2**30:.2f}GiB CC: {gp.major}.{gp.minor})")
            logger.info(f"GPU{i} {gp.name} {gp.total_memory} {gp.major}.{gp.minor}")

    def log_patch(self) -> None:
        """ Check the version of the code, and save a patch to logpath if it was modified. """
        version: str = get_repo_version()
        logger.info(f"repository_version {version}")
        if version == "release":
            print("\033[41mRelease version\033[0m\n")
        elif version.endswith('+'):
            print("\033[31mUncommitted changes detected, saving patch to logdir.\033[0m\n")
            suffix: str = ""
            if self.state_dicts:  # Reloading an existing Trainer
                suffix = time.strftime("%FT%H:%M:%S")
            save_patch(self.logdir / f"patch{suffix}")

    def initialize_rng(self) -> None:
        """ Seed the random number generators, or restore them from a checkpoint. """
        if self.state_dicts and "torch_rng" in self.state_dicts:
            torch.random.set_rng_state(self.state_dicts["torch_rng"])
            # A checkpoint must contain a CUDA RNG state iff CUDA is available.
            assert(("cuda_rng" in self.state_dicts) == torch.cuda.is_available())
            if "cuda_rng" in self.state_dicts:
                torch.cuda.random.set_rng_state_all(self.state_dicts["cuda_rng"])
        else:
            torch.manual_seed(self.config.seed)

        if self.config.get("deterministic") and torch.backends.cudnn.enabled:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    def hook_signals(self) -> None:
        """ Change the behavior of SIGINT (^C) to set the variable `self.interrupted' instead of killing the process; a second ^C restores the default (killing) behavior. """
        self.interrupted: bool = False

        def handler(sig: int, frame: types.FrameType) -> None:
            if multiprocessing.current_process().name != "MainProcess":
                return

            print("\n\033[31mInterrupted, execution will stop at the end of this epoch.\n\033[1mNEXT ^C WILL KILL THE PROCESS!\033[0m\n", file=sys.stderr)
            self.interrupted = True
            signal.signal(signal.SIGINT, signal.SIG_DFL)

        signal.signal(signal.SIGINT, handler)


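# Minimal subclass sketch (everything below is hypothetical, for illustration):
#
#   class MyExperiment(Experiment):
#       def init(self) -> None:
#           self.model = torch.nn.Linear(2, 2).to(self.device)
#
#       def main(self) -> None:
#           while not self.interrupted:
#               break  # one epoch of training would go here
#
#   MyExperiment(parse_args(), LOG_PATH / logdir_name("debug")).run()

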
class SharedLongTensorList:
    """ Pack a list of int64 tensors into a flat data tensor plus an index tensor, so the whole collection lives in two contiguous blocks (e.g. easy to share across processes). """
    def __init__(self, tensor_list: List[torch.Tensor], view: List[int] = [-1]):
        self.view = view

        total_element: int = 0
        for tensor in tensor_list:
            total_element += tensor.numel()
        self.data: torch.Tensor = torch.empty(total_element, dtype=torch.int64)
        self.indices: torch.Tensor = torch.empty(len(tensor_list)+1, dtype=torch.int64)

        # Fill data and indices from the end, consuming (and emptying) the
        # caller's tensor_list so that the source tensors can be freed early.
        data_pos: int = total_element
        indices_pos: int = len(tensor_list)
        self.indices[indices_pos] = data_pos
        while tensor_list:
            tensor: torch.Tensor = tensor_list.pop()
            tensor_size: int = tensor.numel()

            indices_pos -= 1
            data_pos -= tensor_size

            self.indices[indices_pos] = data_pos
            self.data[data_pos:data_pos+tensor_size] = tensor.flatten()
        assert(data_pos == 0)
        assert(indices_pos == 0)

    def __len__(self) -> int:
        return self.indices.shape[0]-1

    def __getitem__(self, key: Union[int, slice, torch.Tensor]) -> Union[torch.Tensor, List[torch.Tensor]]:
        if isinstance(key, slice):
            return [self[value] for value in range(*key.indices(len(self)))]
        elif isinstance(key, int):
            return self.data[self.indices[key]:self.indices[key+1]].view(*self.view)
        elif isinstance(key, torch.Tensor):
            return self[key.item()]
        else:
            raise TypeError("Invalid argument type.")
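

# Usage sketch (tensors are hypothetical; note that the constructor empties
# the input list as it packs it):
#
#   >>> tensors = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
#   >>> packed = SharedLongTensorList(tensors)
#   >>> len(packed)
#   2
#   >>> packed[0]
#   tensor([1, 2, 3])
#   >>> packed[1:2]
#   [tensor([4, 5])]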