-
Notifications
You must be signed in to change notification settings - Fork 6
/
all_demo.py
127 lines (95 loc) · 3.79 KB
/
all_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from tensorflow.python.eager.context import Context
from models.text2palette import T2PGenerator
from models.colorization import ColorizationGenerator
from models.context_embedding import ContextEmbedding
import tensorflow as tf
import numpy as np
import io
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import os
import csv
class AllDemo:
    """End-to-end demo: generate a 5-color palette from text, then colorize.

    Holds two independently restored model pairs:

    * a text-to-palette generator (``T2PGenerator``) with its own
      ``ContextEmbedding``, and
    * a colorization generator (``ColorizationGenerator``) with its own
      ``ContextEmbedding``.
    """

    def __init__(self, t2p_checkpoint_path, color_checkpoint_path):
        """Build both model pairs and restore their weights.

        Args:
            t2p_checkpoint_path: checkpoint path for the text-to-palette
                models, stored under the keys ``gen`` and ``ctx``.
            color_checkpoint_path: checkpoint path for the colorization
                models, with the same ``gen`` / ``ctx`` key layout.
        """
        self.t2p = T2PGenerator()
        self.t2p_ctx = ContextEmbedding()
        self.t2p_ckpt = tf.train.Checkpoint(gen=self.t2p, ctx=self.t2p_ctx)
        # The returned restore status is ignored. NOTE(review): the models'
        # variables presumably only exist after their first call, so matching
        # is likely deferred until then — confirm if restore errors matter.
        self.t2p_ckpt.restore(t2p_checkpoint_path)

        self.color_gen = ColorizationGenerator()
        self.color_ctx = ContextEmbedding()
        self.color_ckpt = tf.train.Checkpoint(gen=self.color_gen, ctx=self.color_ctx)
        self.color_ckpt.restore(color_checkpoint_path)

    def generate(self, image, text, category):
        """Generate a palette for *text*/*category* and colorize *image* with it.

        Args:
            image: path to the input image file (must be a ``str``).
            text: text input passed to both context embeddings.
            category: category input passed to both context embeddings.

        Returns:
            A tuple ``(figure_image, palette_hex)``:

            * ``figure_image`` is a PIL image of a 4-panel figure: grayscale
              input, generated palette, colored image, ground truth.
            * ``palette_hex`` is the list of the 5 generated colors as
              ``'#RRGGBB'`` strings.

        Raises:
            TypeError: if *image* is not a string path.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(image, str):
            raise TypeError('image must be a path string, got %s' % type(image).__name__)
        image_path = image
        pixels = tf.keras.preprocessing.image.img_to_array(
            tf.keras.preprocessing.image.load_img(image_path))
        palette, palette_hex = self._generate_palette(pixels, text, category)
        generated_image = self._colorize(pixels, text, category, palette)
        return self._render(image_path, palette_hex, generated_image), palette_hex

    def _generate_palette(self, pixels, text, category):
        """Sample 5 colors one at a time; return (model palette, hex list)."""
        # 15 zeros, shifted 3 values at a time below as colors are produced.
        palette = [0.] * 15
        palette_hex = []
        for _ in range(5):
            z = np.random.normal(0., 1., size=(1, 128)).astype(np.float32)
            y = self.t2p_ctx([pixels], [text], [category], np.array([palette]))
            new_color = self.t2p(z, y)[0]
            palette_hex.append(self._color_to_hex(new_color))
            # NOTE(review): the nesting of this loop inside the color loop is
            # inferred. As written, every pass drops the oldest 3 values and
            # appends new_color again. The palette handed to colorization
            # therefore ends up as (c4, c4, c5, c5, c5), assuming new_color
            # has 3 values, instead of matching palette_hex (c1..c5).
            # Kept unchanged because the trained models may rely on it.
            # Confirm whether a single shift per color was intended.
            for _ in range(3):
                palette = np.array(list(palette[3:]) + list(new_color.numpy()))
        return palette, palette_hex

    def _colorize(self, pixels, text, category, palette):
        """Run the colorization generator on a grayscale, padded copy of *pixels*."""
        # The context embedding sees the full-size image, before resizing.
        y = self.color_ctx([pixels], [text], [category], np.array([palette]))
        gray = tf.image.resize_with_pad(pixels, 256, 256)
        gray = tf.image.rgb_to_grayscale(gray)
        # Maps 0..255 pixel values to [-1, 1].
        gray = gray / 127.5 - 1
        return self.color_gen([np.array([gray]), y])

    @staticmethod
    def _render(image_path, palette_hex, generated_image):
        """Draw the 4-panel comparison figure and return it as a PIL image."""
        color_map = sns.color_palette(palette=palette_hex, as_cmap=True)
        # Read the original once into memory so the file handle is closed
        # before plotting (it was previously opened twice and never closed).
        with Image.open(image_path) as source:
            original = source.copy()

        fig = plt.figure()
        try:
            plt.subplot(1, 4, 1)
            plt.title('input image')
            plt.axis('off')
            plt.imshow(original.convert('L'), cmap='gray')

            plt.subplot(1, 4, 2)
            plt.title('generated palette')
            plt.axis('off')
            # One row of 5 cells, each painted with the matching palette color.
            sns.heatmap(np.arange(5).reshape((1, 5)), cmap=color_map, cbar=False)

            plt.subplot(1, 4, 3)
            plt.title('colored image')
            plt.axis('off')
            # Assumes the generator outputs values in [-1, 1]; map to [0, 1].
            plt.imshow(generated_image.numpy()[0] * 0.5 + 0.5)

            plt.subplot(1, 4, 4)
            plt.title('ground truth')
            plt.axis('off')
            plt.imshow(original)

            buf = io.BytesIO()
            fig.savefig(buf)
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)
        buf.seek(0)
        return Image.open(buf)

    @staticmethod
    def _color_to_hex(color):
        """Format the first three channels of *color* as ``'#RRGGBB'``.

        Channels are expected in [0, 1]. Each is scaled by 255 and truncated
        to int, as before. It is then clamped to 0..255, so an out-of-range
        model output still yields a valid hex color instead of e.g. ``'-3'``
        or ``'17E'``.
        """
        r, g, b = (min(255, max(0, int(color[i] * 255))) for i in range(3))
        return '#{:02X}{:02X}{:02X}'.format(r, g, b)
if __name__ == '__main__':
    # Run the whole pipeline over every data row of the preprocessed CSV and
    # save five independently sampled results per row.
    demo = AllDemo(t2p_checkpoint_path='./ckpt-t2p/210129/ckpt-70',
                   color_checkpoint_path='./ckpt-color/210129-3/ckpt-100')
    img_dir = './data/images'
    output_dir = './output/whole-pipeline-test'
    os.makedirs(output_dir, exist_ok=True)
    raw_data = './data/preprocessed_data.csv'
    # Columns used: 0 = image file name (relative to img_dir), 1 = text,
    # 2 = category. newline='' is what the csv module expects.
    with open(raw_data, 'r', newline='') as csv_file:
        raw_reader = csv.reader(csv_file)
        next(raw_reader, None)  # skip the first row (presumably a header)
        # start=1 keeps the same row indices (and file names) as before.
        for idx, item in enumerate(raw_reader, start=1):
            image_path = item[0].strip()
            text = item[1].strip()
            category = item[2].strip()
            for generated_idx in range(5):
                generated_img, palette_hex = demo.generate(
                    os.path.join(img_dir, image_path), text, category)
                # The figure is rendered to PNG, which gives an RGBA image.
                # JPEG cannot store alpha, so convert to RGB before saving.
                generated_img.convert('RGB').save(os.path.join(
                    output_dir,
                    '%d-%d-%s.jpg' % (idx, generated_idx, ''.join(palette_hex))))
                print('%d %d' % (idx, generated_idx))