-
Notifications
You must be signed in to change notification settings - Fork 1
/
image_recognition.py
82 lines (59 loc) · 1.84 KB
/
image_recognition.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# -*- coding: utf-8 -*-
"""
Created on Thu May 4 22:30:42 2017
@author: Matt Green
"""
from urllib.request import urlretrieve
from os.path import isfile, isdir
from tqdm import tqdm
import problem_unittests as tests
import tarfile
import helper
import numpy as np
from sklearn import preprocessing
# Directory the script checks for (and creates by extracting the downloaded
# CIFAR-10 tarball when it is missing); later cells read the batches from it.
cifar10_dataset_folder_path = 'cifar-10-batches-py'
class DLProgress(tqdm):
    """tqdm progress bar that can serve as an ``urlretrieve`` ``reporthook``.

    Pass ``pbar.hook`` as the third argument to ``urlretrieve``; each call
    advances the bar by the bytes transferred since the previous call.
    """

    # Block count reported by the previous hook() call; used to turn the
    # cumulative block count urlretrieve reports into a per-call delta.
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        """Advance the bar to reflect ``block_num`` blocks transferred so far.

        Args:
            block_num: Number of blocks transferred so far (cumulative).
            block_size: Size of each block, in bytes.
            total_size: Total download size in bytes. ``urlretrieve`` passes
                -1 when the server does not report a size; in that case (or
                when it is None) the bar's total is left unchanged.
        """
        # -1 means "unknown", not a real size; only accept a positive total.
        if total_size is not None and total_size > 0:
            self.total = total_size
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num
if not isfile('cifar-10-python.tar.gz'):
    import os  # only needed for the rename below

    # Download under a temporary name and rename only once the transfer has
    # finished, so an interrupted download never leaves a truncated archive
    # that the isfile() check above would accept on the next run.
    with DLProgress(unit='B', unit_scale=True, miniters=1,
                    desc='CIFAR-10 Dataset') as pbar:
        urlretrieve(
            'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz',
            'cifar-10-python.tar.gz.part',
            pbar.hook)
    os.replace('cifar-10-python.tar.gz.part', 'cifar-10-python.tar.gz')

if not isdir(cifar10_dataset_folder_path):
    # The with-statement closes the archive; no explicit close() is needed.
    with tarfile.open('cifar-10-python.tar.gz') as tar:
        # The archive comes from the network: use the 'data' extraction
        # filter (rejects absolute paths, '..' components, links escaping the
        # destination, device files) on Pythons that provide it.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(filter='data')
        else:
            tar.extractall()

tests.test_folder_path(cifar10_dataset_folder_path)
# %%
# Explore the dataset: hand one batch number and one sample index within
# that batch to helper.display_stats.
batch_id, sample_id = 1, 5
helper.display_stats(cifar10_dataset_folder_path, batch_id, sample_id)
# %%
def normalize(x):
    """Scale 8-bit image pixel values linearly into the range [0, 1].

    Args:
        x: Array-like of pixel values, assumed to lie in [0, 255], e.g. a
            batch of images shaped (N, 32, 32, 3). Any shape is accepted and
            preserved.

    Returns:
        A float64 ``numpy.ndarray`` with the same shape as ``x``, in which
        0 maps to 0.0 and 255 maps to 1.0.
    """
    # Use a plain linear scale rather than abs() of a mid-grey-centred value:
    # centring and taking abs() folds the range, so black (0) and white (255)
    # end up almost identical (1.0 vs ~0.99) and intensity order is lost.
    return np.asarray(x, dtype=np.float64) / 255.0
# Check normalize() with the project's unit test before it is passed to the
# preprocessing step at the end of the script.
tests.test_normalize(normalize)
# %%
def one_hot_encode(x, n_classes=10):
    """One-hot encode a 1-D sequence of integer class labels.

    Args:
        x: 1-D sequence (list or array) of integer labels in
            ``[0, n_classes)``. Integral floats such as ``3.0`` are accepted.
        n_classes: Number of classes, i.e. the width of each encoded row.
            Defaults to 10 (labels 0-9).

    Returns:
        An integer ``numpy.ndarray`` of shape ``(len(x), n_classes)`` with a
        single 1 in each row, in the column given by that row's label. An
        empty ``x`` gives an array of shape ``(0, n_classes)``.

    Raises:
        ValueError: If any label is not an integer in ``[0, n_classes)``.
            Invalid labels are rejected instead of producing a misleading row.
    """
    labels = np.asarray(x)
    if labels.size == 0:
        return np.zeros((0, n_classes), dtype=int)
    int_labels = labels.astype(np.int64)
    if (np.any(int_labels != labels)
            or int_labels.min() < 0
            or int_labels.max() >= n_classes):
        raise ValueError(
            'one_hot_encode expects integer labels in [0, {}); '
            'got values from {} to {}'.format(
                n_classes, labels.min(), labels.max()))
    # Row i of the identity matrix is exactly the one-hot vector for label i,
    # so fancy-indexing it by the labels encodes the whole batch at once.
    return np.eye(n_classes, dtype=int)[int_labels]
tests.test_one_hot_encode(one_hot_encode)
# %%
# Preprocess Training, Validation, and Testing Data
# Use the dataset folder constant rather than repeating its literal value,
# so the extracted-data path is configured in a single place.
print("Preprocessing and saving data...")
helper.preprocess_and_save_data(cifar10_dataset_folder_path, normalize, one_hot_encode)