-
Notifications
You must be signed in to change notification settings - Fork 8
/
utils.py
67 lines (50 loc) · 1.55 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import json
import os
import uuid
import numpy as np
import requests
import torch
from PIL import Image
from munch import Munch
def load_cfg(cfg_path="config.json"):
    """Load a JSON config file and wrap the parsed value in a ``Munch``.

    Args:
        cfg_path: path of the JSON file to read (default ``"config.json"``).

    Returns:
        Munch: the parsed JSON wrapped in ``Munch`` for attribute-style access
        (presumably the top-level JSON value is an object — verify for callers).

    Raises:
        FileNotFoundError: if ``cfg_path`` does not exist.
    """
    # An explicit raise (instead of ``assert``) still fires under ``python -O``,
    # and the message names the path actually requested.
    if not os.path.exists(cfg_path):
        raise FileNotFoundError(f"{cfg_path} is missing!")
    # Binary mode is fine: json.load accepts bytes and detects UTF-8/16/32.
    with open(cfg_path, 'rb') as f:
        cfg = json.load(f)
    return Munch(cfg)
# Directory that save_images() writes generated PNG files into.
cache_path = "cache"
# Directory that load_weights() stores downloaded weight files in.
data_path = "./data"
def load_weights(file_name, download_url, device, *, timeout=60):
    """Return the contents of ``data_path/file_name`` via ``torch.load``, downloading it first if absent.

    Args:
        file_name: name of the file inside ``data_path``.
        download_url: URL fetched with ``requests`` when the file is not present.
        device: passed to ``torch.device`` and used as ``map_location``.
        timeout: seconds ``requests`` waits to connect or between received
            bytes (not a cap on the total download time).

    Returns:
        Whatever ``torch.load`` deserializes from the file.

    Raises:
        requests.HTTPError: if the server answers with an error status; no
            file is left in ``data_path`` in that case.
    """
    os.makedirs(data_path, exist_ok=True)
    weight_path = os.path.join(data_path, file_name)
    if not os.path.exists(weight_path):
        print(f"Downloading from: {download_url}...")
        # Download into a temporary sibling and rename only on success, so an
        # HTTP error page or an interrupted transfer is never cached as weights.
        tmp_path = weight_path + ".part"
        try:
            with requests.get(download_url, stream=True, timeout=timeout) as res:
                res.raise_for_status()
                with open(tmp_path, 'wb') as f:
                    # Stream in chunks instead of holding the whole file in memory.
                    for chunk in res.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            os.replace(tmp_path, weight_path)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        print(f'File saved at: {weight_path}')
    # NOTE(security): torch.load uses pickle unless weights_only=True is in
    # effect, and pickle can execute arbitrary code — only download from
    # trusted URLs.
    return torch.load(weight_path, map_location=torch.device(device))
def set_eval_mode(nets):
    """Call ``.eval()`` on every value stored in the *nets* mapping.

    Returns None; the return values of ``.eval()`` are discarded.
    """
    modules = nets.values()
    for module in modules:
        module.eval()
def to_device(nets, device):
    """Call ``.to(device)`` on every value stored in the *nets* mapping.

    The return values of ``.to`` are discarded, so this only has an effect when
    ``.to`` moves the object in place (as ``torch.nn.Module.to`` does).
    """
    for key in nets:
        nets[key].to(device)
def denormalize(x):
    """Rescale tensor ``x`` as ``(x + 1) / 2`` and clamp the result to ``[0, 1]``.

    The in-place clamp acts on the newly computed tensor, so ``x`` is not modified.
    """
    return x.add(1).div(2).clamp_(min=0, max=1)
def get_image_name():
    """Return a random file name: 32 lowercase hex characters from uuid4 plus ``.png``."""
    token = uuid.uuid4().hex
    return token + ".png"
def save_images(imgs):
    """Write a batch of image tensors to PNG files under ``cache_path``.

    Args:
        imgs: tensor laid out as (batch, channel, height, width) — the layout
            implied by the ``transpose((0, 2, 3, 1))`` below. Values are passed
            through ``denormalize`` first. Single-channel batches are saved as
            grayscale images.

    Returns:
        list[str]: paths of the saved files (backslashes replaced by forward
        slashes), in batch order.
    """
    # detach(): .numpy() refuses tensors that require grad.
    imgs = denormalize(imgs.detach())
    # Round to the nearest integer instead of truncating (0.999 -> 255, not 254).
    imgs = (imgs * 255 + 0.5).clamp_(0, 255)
    arrays = imgs.cpu().numpy().astype(np.uint8).transpose((0, 2, 3, 1))
    os.makedirs(cache_path, exist_ok=True)
    filenames = []
    for arr in arrays:
        # Image.fromarray cannot handle an (H, W, 1) array; drop the channel axis.
        if arr.shape[-1] == 1:
            arr = arr[..., 0]
        filename = os.path.join(cache_path, get_image_name()).replace("\\", "/")
        Image.fromarray(arr).save(filename)
        filenames.append(filename)
    return filenames