-
Notifications
You must be signed in to change notification settings - Fork 21
/
ds_generator.py
47 lines (34 loc) · 996 Bytes
/
ds_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#!/usr/bin/python
# -*- coding: utf-8 -*-
import time
import cv2
import tensorflow as tf
# Maps the label string found in the second CSV column to an integer class id.
# The keys are Chinese: '猫' means "cat" and '狗' means "dog".
label_map = {'猫': 0, '狗': 1}
def gen():
    """Yield ``(image, label)`` samples listed in ``train.csv``, forever.

    Each non-blank line of ``train.csv`` is split on commas; the first field
    is an image path and the second is a label string that is looked up in
    ``label_map``. Images are loaded with OpenCV and resized to 224x224.
    After the last row the generator wraps around to the first one, so it
    never terminates on its own.

    Yields:
        tuple: ``(image, label)`` where ``image`` is the resized array
        returned by ``cv2.resize`` and ``label`` is the ``label_map`` value.

    Raises:
        ValueError: if ``train.csv`` contains no non-blank lines.
        IOError: if an image listed in the CSV cannot be read.
        KeyError: if a label string is not present in ``label_map``.
    """
    with open('train.csv') as f:
        # Skip blank lines (e.g. a trailing newline at the end of the file);
        # they would otherwise produce a one-field row and fail on row[1].
        rows = [line.strip().split(',') for line in f if line.strip()]

    if not rows:
        raise ValueError('train.csv contains no samples')

    index = 0
    while True:
        path, label_name = rows[index][0], rows[index][1]
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cv2.resize.
        if image is None:
            raise IOError('cannot read image: {}'.format(path))
        image = cv2.resize(image, (224, 224))
        label = label_map[label_name]
        yield (image, label)
        # Wrap around so the stream is endless.
        index = (index + 1) % len(rows)
def create_dataset(batch_size=2, num_batches=100):
    """Build a batched ``tf.data`` pipeline over ``gen`` and print its labels.

    The generator output is declared as a ``float32`` image of shape
    ``[224, 224, 3]`` and a scalar ``int32`` label. Batches are pulled through
    a one-shot iterator inside a ``tf.Session``; for every step the step index
    and the batch's labels are printed, and finally the elapsed wall-clock
    time in seconds (measured from just before ``get_next`` is called).

    Args:
        batch_size: number of samples per batch (default 2).
        num_batches: number of batches to fetch and print (default 100).
    """
    dataset = tf.data.Dataset.from_generator(
        gen,
        (tf.float32, tf.int32),
        (tf.TensorShape([224, 224, 3]), tf.TensorShape([])))
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()

    start = time.time()
    # Only the label component is fetched below; the image component of the
    # batch is discarded.
    _, labels = iterator.get_next()
    with tf.Session() as sess:
        for step in range(num_batches):
            batch_labels = sess.run(labels)
            print('{} -> {}'.format(step, batch_labels))
        print(time.time() - start)
def main():
    """Script entry point: run the dataset demo in ``create_dataset``."""
    create_dataset()


# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()