utils.py (14138B)
from typing import Any, Callable, Dict, List, NoReturn, Optional, Union
import hashlib
import importlib
import logging
import multiprocessing
import os
import pathlib
import signal
import subprocess
import sys
import time
import types

import torch

logger = logging.getLogger(__name__)

# Tell static analyzers (pytype) that this module gains attributes at runtime:
# import_environment() below injects names straight into globals().
_HAS_DYNAMIC_ATTRIBUTES = True


def import_environment(name: str, cast: type = str) -> None:
    """ Import an environment variable into the global namespace.

    The variable is read from os.environ, converted with `cast`, and bound as
    a module-level global of the same name.  Exits the whole process with
    status 1 when the variable is not set.
    """
    try:
        globals()[name] = cast(os.environ[name])
    except KeyError:
        print(f"ERROR: {name} environment variable is not set.", file=sys.stderr)
        sys.exit(1)


# Mandatory configuration: importing this module exits with status 1 when
# either variable is missing from the environment.
import_environment("DATA_PATH", pathlib.Path)
import_environment("LOG_PATH", pathlib.Path)


class dotdict(dict):
    """ Dictionary which can be accessed through var.key instead of var["key"]. """
    def __getattr__(self, name: str) -> Any:
        if name not in self:
            raise AttributeError(f"Config key {name} not found")
        # Wrap nested plain dicts on the fly so chained attribute access works.
        return dotdict(self[name]) if type(self[name]) is dict else self[name]
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


def eval_arg(config: Dict[str, Any], arg: str) -> None:
    """
    Evaluate arg in the context config, and update it.

    The argument is expected to be of the form:
        (parent.)*key(=value)?
    If no value is provided, the key is assumed to be a boolean and True is
    assigned to it.  When passing a string argument through the shell, it must
    be enclosed in quotes (like all python strings), which usually need to be
    escaped.
    """
    key: str
    value: Any
    if '=' in arg:
        key, value = arg.split('=', maxsplit=1)
        # NOTE(security): eval() runs arbitrary code; tolerated here because
        # the value comes from the experimenter's own command line, never from
        # untrusted input.
        value = eval(value, config)
    else:
        key, value = arg, True
    # Walk the dotted path down to the parent dict, then assign the leaf key.
    path: List[str] = key.split('.')
    for component in path[:-1]:
        config = config[component]
    config[path[-1]] = value
    config.pop("__builtins__", None)  # eval() injects this into its globals


def import_arg(config: Dict[str, Any], arg: str) -> None:
    """
    Load file arg, and update config with its content.

    The file is loaded in an independent context, all the variables defined in
    the file (even through import) are added to config, with the exception of
    builtins and whole modules.
    """
    # Accept a filesystem-style path and turn it into a module path.
    if arg.endswith(".py"):
        arg = arg[:-3].replace('/', '.')
    module: types.ModuleType = importlib.import_module(arg)
    for key, value in vars(module).items():
        # Skip builtins, dunder names and imported modules; keep everything else.
        if key not in module.__builtins__ and not key.startswith("__") and not isinstance(value, types.ModuleType):  # pytype: disable=attribute-error
            config[key] = value


def parse_args() -> dotdict:
    """
    Parse command line arguments and return config dictionary.

    Two kinds of argument are supported:
    - When the argument starts with -- it is evaluated by the eval_arg function
    - Otherwise the argument is assumed to be a file which is loaded by the import_arg function
    """
    config: Dict[str, Any] = {}
    # Self-reference so --key=value expressions can read other config values.
    config["config"] = config
    for arg in sys.argv[1:]:
        if arg.startswith("--"):
            eval_arg(config, arg[2:])
        else:
            import_arg(config, arg)
    config.pop("config")
    return dotdict(config)


def display_dict(output: Callable[[str], None], input: Dict[str, Any], depth: int = 0) -> None:
    """ Display nested dictionaries in input using the provided output function.

    Keys are indented with one tab per nesting level; leaf values are printed
    on the same line as their key.
    """
    for key, value in input.items():
        indent = '\t'*depth
        output(f"{indent}{key}:")
        if isinstance(value, dict):
            output('\n')
            display_dict(output, value, depth+1)
        else:
            output(f" {value}\n")


def print_dict(input: Dict[str, Any]) -> None:
    """ Print dictionary to standard output. """
    display_dict(lambda x: print(x, end=""), input)


def log_dict(logger: logging.Logger, input: Dict[str, Any]) -> None:
    """ Log dictionary to the provided logger. """
    # Accumulate partial writes and emit one log record per completed line.
    class log:
        buf: str = ""

        def __call__(self, x: str) -> None:
            self.buf += x
            if self.buf.endswith('\n'):
                logger.info(self.buf[:-1])
                self.buf = ""
    display_dict(log(), input)


def flatten_dict(input: Dict[str, Any]) -> Dict[str, Union[bool, int, float, str]]:
    """
    Replace nested dict by dot-separated keys, and cast values to simple types.

    repr() is used to cast non-base-type to str.
    """
    def impl(result: Dict[str, Union[bool, int, float, str]], input: Dict[str, Any], prefix: str) -> None:
        for key, value in input.items():
            if isinstance(value, dict):
                # Bug fix: carry the accumulated prefix into the recursion.
                # Previously nested dicts recursed with f"{key}." alone, so
                # keys two or more levels deep dropped their ancestors
                # ({"a": {"b": {"c": 1}}} produced "b.c" instead of "a.b.c").
                impl(result, value, f"{prefix}{key}.")
            else:
                result[f"{prefix}{key}"] = value if type(value) in [bool, int, float, str] else repr(value)

    result: Dict[str, Union[bool, int, float, str]] = {}
    impl(result, input, "")
    return result


def get_repo_version() -> str:
    """ Get the code repository version.

    Returns "release" when the code does not live in a git checkout, otherwise
    the first 8 characters of the HEAD commit hash, suffixed with '+' when
    locally modified files are detected.
    """
    repo_dir = pathlib.Path(__file__).parents[0]
    result: subprocess.CompletedProcess = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            encoding="utf-8",
            cwd=repo_dir)

    if result.returncode != 0:
        return "release"
    commit_hash: str = result.stdout.strip()[:8]

    result = subprocess.run(
            ["git", "status", "--porcelain"],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            encoding="utf-8",
            cwd=repo_dir)
    modified_flag: str = ""
    # NOTE(review): " M " only matches files modified in the worktree but not
    # staged; staged ("M  ") or untracked ("?? ") changes are not flagged —
    # confirm whether that is intentional.
    for line in result.stdout.split('\n'):
        if line.startswith(" M "):
            modified_flag = "+"
            break

    return f"{commit_hash}{modified_flag}"


def experiment_name(name: str) -> str:
    """ Name of the experiment (contains repository version, argument and time). """
    args: str = ' '.join(sys.argv[1:])
    version: str = get_repo_version()
    stime: str = time.strftime("%FT%H:%M:%S")
    return f"{name} {version} {args} {stime}"


def logdir_name(name: str) -> str:
    """ Name of the experiment directory, it should be the experiment_name clipped because of filesystem constraints. """
    subdir: str = experiment_name(name).replace('/', '_')
    # Most filesystems cap a path component at 255 bytes; when over the limit,
    # truncate and append a short hash so distinct names stay distinct.
    if len(subdir) > 255:
        sha1: str = hashlib.sha1(subdir.encode("utf-8")).digest().hex()[:16]
        subdir = subdir[:255-17] + ' ' + sha1
    return subdir


def fix_transformers_logging_handler() -> None:
    """ The transformers package from huggingface installs its own logger on import, I don't want it.

    Removes every handler currently attached to the root logger.
    """
    logger: logging.Logger = logging.getLogger()
    for handler in logger.handlers:
        logger.removeHandler(handler)


def add_logging_handler(logdir: pathlib.Path) -> None:
    """ Configure root logging to append INFO-level records to logdir/log. """
    logfile: pathlib.Path = logdir / "log"
    logging.basicConfig(format="%(asctime)s\t%(levelname)s:%(name)s:%(message)s", filename=logfile, filemode='a', level=logging.INFO)


def save_patch(outpath: pathlib.Path) -> None:
    """ Save a file at the given patch containing the diff between the current code and the last commit. """
    repo_dir = pathlib.Path(__file__).parents[0]

    # Stream git's output straight into the target file.
    with outpath.open("w") as outfile:
        result: subprocess.CompletedProcess = subprocess.run(
                ["git", "diff", "HEAD"],
                stdout=outfile,
                stderr=subprocess.DEVNULL,
                encoding="utf-8",
                cwd=repo_dir)

    assert result.returncode == 0


class Experiment:
    """
    Base class for running an experiment.

    Calling run on an instance of this class will call the init() then main() functions.
    The sole purpose of this class is to make an experiment "prettier": it displays config values, store a diff of the repo in the experiment logdir, etc.

    Config:
        deterministic: run in deterministic mode
        seed: seed for random number generators
    """

    # Subclasses set arbitrary attributes (e.g. self.device); silence pytype.
    _HAS_DYNAMIC_ATTRIBUTES = True

    def __init__(self, config: dotdict, logdir: pathlib.Path, state_dicts: Optional[Dict[str, Any]] = None) -> None:
        self.config = config
        self.logdir = logdir
        # state_dicts is non-None when resuming a previous run (RNG states, …).
        self.state_dicts = state_dicts

    def init(self) -> None:
        """ Prepare the experiment (e.g. initialize datasets and models). """
        pass

    def main(self) -> NoReturn:
        """ Run the experiment in itself. """
        raise NotImplementedError("Subclasses must implement a main method")

    def close(self) -> None:
        """ Free used resources (e.g. close opened files). """
        pass

    def run(self) -> None:
        """ Run the whole experiment with setting ups, etc.

        Order matters: RNG is seeded before init() so model/data creation is
        reproducible, and signals are hooked only once setup is done.
        """
        self.log_environment()
        self.log_patch()
        self.initialize_rng()
        self.init()
        self.hook_signals()
        self.main()
        self.close()

    def log_environment(self) -> None:
        """ Display information about the environment. """
        print(f"logdir is \033[1m\033[33m{self.logdir}\033[0m")
        self.version_check()
        self.detect_gpus()
        print("")

        print("\033[1m\033[33mConfiguration\033[0m")
        print_dict(self.config)
        log_dict(logging.getLogger("config"), self.config)
        print("")

    def version_check(self) -> None:
        """ Check the version of the main dependencies. """
        python_version: str = '.'.join(map(str, sys.version_info[:3]))
        torch_version: str = torch.__version__
        cuda_available: str = str(torch.cuda.is_available())

        logger.info(f"python version {python_version}")
        logger.info(f"torch version {torch_version}")
        logger.info(f"cuda available {cuda_available}")

        # Highlight problematic versions in bold red on the console.
        def problem(msg: str) -> str:
            return f"\033[1m\033[31m{msg}\033[0m"
        if sys.version_info < (3, 7):
            python_version = problem(python_version)
        # Strip any local build suffix (e.g. "1.8.0+cu111") before comparing.
        if list(map(int, torch_version.split('+')[0].split('.'))) < [1, 6]:
            torch_version = problem(torch_version)
        if cuda_available == "False":
            cuda_available = problem(cuda_available)
        print(f"python version: {python_version}, torch version: {torch_version}, cuda available: {cuda_available}")

    def detect_gpus(self) -> None:
        """ Display available gpus and set self.device.

        self.device becomes cuda:0 when at least one GPU is visible, cpu
        otherwise.
        """
        count: int = torch.cuda.device_count()

        if count == 0:
            print("\033[1m\033[31mNo GPU available\033[0m")
            logger.warning("no GPU available")
            self.device = torch.device("cpu")
        else:
            self.device = torch.device("cuda:0")

        for i in range(count):
            gp = torch.cuda.get_device_properties(i)
            print(f"GPU{i}: \033[33m{gp.name}\033[0m (Mem: {gp.total_memory/2**30:.2f}GiB CC: {gp.major}.{gp.minor})")
            logger.info(f"GPU{i} {gp.name} {gp.total_memory} {gp.major}.{gp.minor}")

    def log_patch(self) -> None:
        """ Check the version of the code, and save a patch to logpath if it was modified. """
        version: str = get_repo_version()
        logger.info(f"repository_version {version}")
        if version == "release":
            print("\033[41mRelease version\033[0m\n")
        elif version.endswith('+'):
            print("\033[31mUncommited changes detected, saving patch to logdir.\033[0m\n")
            # When resuming, timestamp the patch name so earlier patches are
            # not overwritten.
            suffix: str = ""
            if self.state_dicts:  # Reloading an existing Trainer
                suffix = time.strftime("%FT%H:%M:%S")
            save_patch(self.logdir / f"patch{suffix}")

    def initialize_rng(self) -> None:
        """ Restore RNG states when resuming, otherwise seed from config.seed. """
        if self.state_dicts and "torch_rng" in self.state_dicts:
            torch.random.set_rng_state(self.state_dicts["torch_rng"])
            # A checkpoint with (resp. without) CUDA RNG state must be resumed
            # on a machine with (resp. without) CUDA.
            assert ("cuda_rng" in self.state_dicts) == torch.cuda.is_available()
            if "cuda_rng" in self.state_dicts:
                torch.cuda.random.set_rng_state_all(self.state_dicts["cuda_rng"])
        else:
            torch.manual_seed(self.config.seed)

        if self.config.get("deterministic") and torch.backends.cudnn.enabled:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    def hook_signals(self) -> None:
        """ Change the behavior of SIGINT (^C) to change a variable `self.interrupted' before killing the process.

        The first ^C sets self.interrupted (for a graceful stop at the end of
        the epoch) and restores the default handler, so a second ^C kills the
        process.
        """
        self.interrupted: bool = False

        def handler(sig: int, frame: Optional[types.FrameType]) -> None:
            # Workers inherit the handler; only the main process should react.
            if multiprocessing.current_process().name != "MainProcess":
                return

            print("\n\033[31mInterrupted, execution will stop at the end of this epoch.\n\033[1mNEXT ^C WILL KILL THE PROCESS!\033[0m\n", file=sys.stderr)
            self.interrupted = True
            signal.signal(signal.SIGINT, signal.SIG_DFL)

        signal.signal(signal.SIGINT, handler)


class SharedLongTensorList:
    """ Pack a list of int64 tensors into one flat data tensor plus offsets.

    The single-allocation layout is presumably intended to be shared cheaply
    across worker processes -- TODO confirm against callers.
    """

    def __init__(self, tensor_list: List[torch.Tensor], view: Optional[List[int]] = None):
        """ Build the packed storage.

        WARNING: tensor_list is consumed (emptied) during construction so each
        source tensor can be released as soon as it is copied.

        view: shape applied to each element on retrieval, defaults to [-1]
        (flat).  (The default used to be a shared mutable [-1] literal; it is
        now a None sentinel with identical behavior.)
        """
        self.view = [-1] if view is None else view

        total_element: int = 0
        for tensor in tensor_list:
            total_element += tensor.numel()
        self.data: torch.Tensor = torch.empty(total_element, dtype=torch.int64)
        self.indices: torch.Tensor = torch.empty(len(tensor_list)+1, dtype=torch.int64)

        # Fill from the back while popping, so memory of each source tensor
        # can be reclaimed immediately after it is copied.
        data_pos: int = total_element
        indices_pos: int = len(tensor_list)
        self.indices[indices_pos] = data_pos
        while tensor_list:
            tensor: torch.Tensor = tensor_list.pop()
            tensor_size: int = tensor.numel()

            indices_pos -= 1
            data_pos -= tensor_size

            self.indices[indices_pos] = data_pos
            self.data[data_pos:data_pos+tensor_size] = tensor.flatten()
        assert data_pos == 0
        assert indices_pos == 0

    def __len__(self) -> int:
        return self.indices.shape[0]-1

    def __getitem__(self, key: Union[int, slice, torch.Tensor]) -> Union[torch.Tensor, List[torch.Tensor]]:
        """ Retrieve element(s); ints/0-d tensors give a tensor, slices a list. """
        if isinstance(key, slice):
            return [self[value] for value in range(*key.indices(len(self)))]
        elif isinstance(key, int):
            return self.data[self.indices[key]:self.indices[key+1]].view(*self.view)
        elif isinstance(key, torch.Tensor):
            # 0-d tensor index: unwrap to a python int and recurse.
            return self[key.item()]
        else:
            raise TypeError("Invalid argument type.")