-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
32 lines (24 loc) · 967 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os

import matplotlib.pyplot as plt
import torchvision

from multi_probe_lsh import MultiProbeLSH
def _save_image(image, path):
    """Render one image array to ``path`` with the axes hidden.

    The current figure is cleared first. Without this, every ``imshow`` call
    stacks another image layer on the same axes, so memory use and
    ``savefig`` render time grow with each image saved.
    """
    plt.clf()
    # clf() resets the axes, so hiding them must be re-applied every time.
    plt.axis('off')
    plt.imshow(image)
    plt.savefig(path)


if __name__ == '__main__':
    # Parameters passed positionally to MultiProbeLSH(dim, l, m, w).
    # dim = 32 * 32 * 3, the length of one flattened CIFAR-10 image.
    dim = 3072
    # NOTE(review): l / m / w presumably follow the usual multi-probe LSH
    # notation (tables, hashes per table, bucket width) -- confirm against
    # MultiProbeLSH.
    l = 2
    m = 3
    w = 128
    lsh = MultiProbeLSH(dim, l, m, w)

    train_dataset = torchvision.datasets.CIFAR10(root='./data/CIFAR10', train=True, download=True)
    test_dataset = torchvision.datasets.CIFAR10(root='./data/CIFAR10', train=False, download=True)

    # Index every training image under its position (as a string id), so
    # query results can be mapped back to train_dataset.data.
    for i, image in enumerate(train_dataset.data):
        lsh.insert(image.ravel(), str(i))

    # savefig does not create missing directories; make sure it exists.
    out_dir = './data/lsh'
    os.makedirs(out_dir, exist_ok=True)

    for test_aim in range(100):
        query = test_dataset.data[test_aim].ravel()
        # Assumes query() returns an iterable of the ids given to insert().
        res = lsh.query(query)
        print(res)
        _save_image(test_dataset.data[test_aim],
                    os.path.join(out_dir, f'test_{test_aim}.png'))
        for idx in res:
            _save_image(train_dataset.data[int(idx)],
                        os.path.join(out_dir, f'test_{test_aim}_{idx}.png'))