-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
54 lines (43 loc) · 1.69 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import io
import pickle
import random
import logging
import mlflow
import numpy as np
import torch
from hydra import initialize, compose
from omegaconf import DictConfig, ListConfig
# Module-level logger, named after this module via ``__name__``.
# Its threshold is set to WARNING, so records below that level sent through
# this logger are discarded.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.WARNING)
def set_seed(seed: int):
    """Seed the NumPy, Python ``random`` and PyTorch random number generators.

    The same ``seed`` is handed, in this order, to ``np.random.seed``,
    ``random.seed`` and ``torch.manual_seed``.

    Args:
        seed: Value passed unchanged to each of the three seeding functions.
    """
    seeders = (np.random.seed, random.seed, torch.manual_seed)
    for seed_fn in seeders:
        seed_fn(seed)
def instantiate_config(config_path: str):
    """Compose a Hydra configuration from a ``<directory>/<file>`` style path.

    The path is split at its last ``/``: the part before it is passed to
    ``initialize`` as the config directory and the part after it to
    ``compose`` as the config name. The full, unsplit ``config_path`` is
    used as the Hydra job name.

    A path that contains no ``/`` is treated as a config file located in the
    config directory ``"."``, so a bare file name is never passed as the
    directory.

    Args:
        config_path: Slash-separated path to the config file.
            NOTE(review): Hydra resolves ``initialize``'s ``config_path``
            relative to the calling module, so this is presumably relative
            to this file rather than the working directory - confirm against
            the Hydra version in use.

    Returns:
        The configuration object returned by ``compose``.
    """
    config_dir, separator, config_file = config_path.rpartition("/")
    if not separator:
        # No directory component: rpartition put the whole string in
        # ``config_file`` and left ``config_dir`` empty.
        config_dir = "."

    # ``initialize`` is called directly in this function (not through a
    # helper) so that any caller-relative path resolution it performs keeps
    # seeing the same call site.
    with initialize(config_path=config_dir, job_name=config_path):
        return compose(config_name=config_file)
def log_params_from_omegaconf_dict(params):
    """Log every top-level entry of ``params`` through ``_explore_recursive``.

    Args:
        params: Mapping-like object exposing ``items()``; each ``(name,
            value)`` pair is forwarded to ``_explore_recursive``, which is
            responsible for the actual MLflow logging.
    """
    top_level_entries = params.items()
    for entry_name, entry_value in top_level_entries:
        _explore_recursive(entry_name, entry_value)
def _explore_recursive(parent_name, element):
    """Log one configuration node to MLflow, flattening nested containers.

    Children of a ``DictConfig`` or ``ListConfig`` are logged under dotted
    names: ``<parent_name>.<key>`` for dict entries and
    ``<parent_name>.<index>`` for list items. A child that is itself a
    ``DictConfig`` or ``ListConfig`` is recursed into (for lists as well as
    dicts), so nested structures end up as individual parameters instead of
    one stringified value. Any other child is passed to ``mlflow.log_param``
    as is.

    When ``element`` itself is not a container it is logged under
    ``parent_name`` if it is an ``int``, ``float`` or ``str``; any other
    value is skipped with a warning.

    Args:
        parent_name: Dotted parameter name accumulated so far.
        element: Configuration node to log.
    """
    if isinstance(element, DictConfig):
        children = element.items()
    elif isinstance(element, ListConfig):
        children = enumerate(element)
    elif isinstance(element, (int, float, str)):
        # Scalar passed in directly (e.g. a top-level config value).
        mlflow.log_param(parent_name, element)
        return
    else:
        logger.warning(f"Configuration field {parent_name} with value {element} not logged in mlflow.")
        return

    # Dict entries and list items share the same handling: descend into
    # nested containers, log everything else as a leaf.
    for key, value in children:
        child_name = f'{parent_name}.{key}'
        if isinstance(value, (DictConfig, ListConfig)):
            _explore_recursive(child_name, value)
        else:
            mlflow.log_param(child_name, value)
class CPU_Unpickler(pickle.Unpickler):
    """``pickle.Unpickler`` that redirects torch storage loading to the CPU.

    The lookup of ``torch.storage._load_from_bytes`` is intercepted and
    replaced by a loader that calls ``torch.load`` with
    ``map_location='cpu'``; every other global is resolved by the base class.
    Presumably this allows pickles created on a GPU machine to be read on a
    CPU-only one - confirm with the call sites.

    NOTE: unpickling can execute arbitrary code. Only use this class on data
    from a trusted source.
    """

    def find_class(self, module, name):
        """Resolve a pickled global, substituting a CPU-mapping storage loader.

        Args:
            module: Module name recorded in the pickle stream.
            name: Attribute name recorded in the pickle stream.

        Returns:
            A one-argument callable that loads the given bytes with
            ``torch.load(..., map_location='cpu')`` when the request is for
            ``torch.storage._load_from_bytes``; otherwise whatever
            ``pickle.Unpickler.find_class`` returns.
        """
        wants_storage_loader = (module, name) == ('torch.storage', '_load_from_bytes')
        if not wants_storage_loader:
            return super().find_class(module, name)

        def load_on_cpu(raw_bytes):
            # Wrap the raw bytes in a file-like object, as torch.load expects.
            return torch.load(io.BytesIO(raw_bytes), map_location='cpu')

        return load_on_cpu