-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_cnns.py
101 lines (77 loc) · 2.97 KB
/
train_cnns.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import numpy as np
from Solver import CNNSolver
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from maze_dataset import MazeDatasetSnapshots
from experiment_conf import MAZE, DATASET, CNNS, GLOBAL
def train(net, train_loader, epoch_num, loss_fn, optim, device=None):
    """Run one training epoch of ``net`` over ``train_loader``.

    Each item yielded by ``train_loader`` is unpacked as a 4-tuple; only the
    first two elements (inputs ``x`` and integer class targets ``y``) are used.
    A dimension is inserted at index 1 of ``x`` before the forward pass
    (presumably a single input channel for the CNN -- confirm against
    ``CNNSolver``). A progress line is rewritten in place after every batch.

    Args:
        net: model to train; switched to training mode.
        train_loader: iterable of ``(x, y, _, _)`` batches.
        epoch_num: epoch index, used only for the progress line.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optim: optimizer over ``net``'s parameters.
        device: device to move batches to; defaults to the module-level
            ``DEVICE`` when omitted.

    Returns:
        ``(loss, acc)`` where ``loss`` is the detached loss tensor of the
        *last* batch and ``acc`` is the cumulative accuracy over the epoch,
        in percent.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    if device is None:
        device = DEVICE
    net.train()
    total = 0
    correct = 0
    loss = None
    acc = 0.0
    for x, y, _, _ in train_loader:
        x = x.unsqueeze(1).to(device)
        y = y.to(device)
        optim.zero_grad()
        pred = net(x)
        loss = loss_fn(pred, y)
        loss.backward()
        optim.step()
        predicted = pred.argmax(dim=1)
        total += y.shape[0]
        correct += (predicted == y).sum().item()
        # Running accuracy over all batches seen so far in this epoch.
        acc = 100 * correct / total
        # "\r" overwrites the same console line on every batch.
        print(f"Epoch: {epoch_num:03d}\t loss: {loss.item():5.5f}\t acc: {acc:6.3f} %", end="\r")
    if loss is None:
        # Previously this surfaced as an opaque UnboundLocalError.
        raise ValueError("train_loader yielded no batches; nothing to train on")
    # Terminate the in-place progress line.
    print()
    # Detach so the caller does not keep the (already consumed) graph alive.
    return loss.detach(), acc
@torch.no_grad()
def evaluate(net, eval_loader, epoch_num, device=None):
    """Compute classification accuracy of ``net`` over ``eval_loader``.

    Runs with autograd disabled and ``net`` in eval mode. Each item yielded by
    ``eval_loader`` is unpacked as a 4-tuple; only the inputs ``x`` and integer
    class targets ``y`` are used. A dimension is inserted at index 1 of ``x``
    before the forward pass (presumably a single input channel -- confirm
    against ``CNNSolver``). One summary line is printed; its loss field is
    always NaN because no loss is computed during evaluation.

    Args:
        net: model to evaluate.
        eval_loader: iterable of ``(x, y, _, _)`` batches.
        epoch_num: epoch index, used only for the printed summary.
        device: device to move batches to; defaults to the module-level
            ``DEVICE`` when omitted.

    Returns:
        Accuracy over all evaluated samples, in percent.

    Raises:
        ValueError: if ``eval_loader`` yields no samples.
    """
    if device is None:
        device = DEVICE
    net.eval()
    total = 0
    correct = 0
    for x, y, _, _ in eval_loader:
        x = x.unsqueeze(1).to(device)
        y = y.to(device)
        predicted = net(x).argmax(dim=1)
        total += y.shape[0]
        correct += (predicted == y).sum().item()
    if total == 0:
        # Previously this surfaced as a bare ZeroDivisionError.
        raise ValueError("eval_loader yielded no samples; cannot compute accuracy")
    acc = 100 * correct / total
    print(f"Epoch: {epoch_num:03d}\t loss: {np.nan:5.5f}\t acc: {acc:6.3f} %")
    return acc
GRID_SIZE = MAZE['GRID_SIZE']
MAX_PATH_LENGTH = MAZE['MAX_PATH_LENGTH']
SHORTEST_PATH = MAZE['SHORTEST_PATH']
NUM_TRAIN = DATASET['NUM_TRAIN']
NUM_EVAL = DATASET['NUM_EVAL']
BATCH_SIZE = CNNS['BATCH_SIZE']
MAX_EPOCH = CNNS['MAX_EPOCH']
D_MODEL = CNNS['D_MODEL']
FC_DIM = CNNS['FC_DIM']
LR = CNNS['LR']
WD = CNNS['WD']
DEVICE = GLOBAL['DEVICE']
SEED = GLOBAL['SEED']

# Checkpoints are only considered for epochs with index strictly greater than
# this value; earlier epochs are treated as warm-up and ignored for selection.
SAVE_AFTER_EPOCH = 100
# Where the best post-warm-up model weights are written.
CHECKPOINT_PATH = "./cnn_solver_best.pt"
# Worker processes per DataLoader.
NUM_WORKERS = 8

if __name__ == "__main__":
    # Seed every RNG the pipeline may touch for reproducible runs.
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)

    train_dataset = MazeDatasetSnapshots(NUM_TRAIN, GRID_SIZE, MAX_PATH_LENGTH, SHORTEST_PATH)
    eval_dataset = MazeDatasetSnapshots(NUM_EVAL, GRID_SIZE, MAX_PATH_LENGTH, SHORTEST_PATH)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

    # Model runs in float64 (.double()); inputs are presumably float64 too.
    solver = CNNSolver(d_model=D_MODEL, dim_feedforward=FC_DIM).double()
    pytorch_total_params = sum(p.numel() for p in solver.parameters() if p.requires_grad)
    print(f"Number of parameters: {pytorch_total_params}")
    solver = solver.to(DEVICE)

    optim = Adam(solver.parameters(), lr=LR, weight_decay=WD)
    loss_fn = nn.CrossEntropyLoss()

    # Best eval accuracy among checkpoint-eligible epochs only. Tracking it
    # during warm-up as well could leave a warm-up peak that no later epoch
    # matches, so no checkpoint would ever be written.
    best_acc = 0
    saved_checkpoint = False
    for epoch in range(MAX_EPOCH):
        train_loss, train_acc = train(solver, train_loader, epoch, loss_fn, optim)
        eval_acc = evaluate(solver, eval_loader, epoch)
        # ">=" keeps the later of equally good epochs.
        if epoch > SAVE_AFTER_EPOCH and eval_acc >= best_acc:
            torch.save(solver.state_dict(), CHECKPOINT_PATH)
            best_acc = eval_acc
            saved_checkpoint = True

    if not saved_checkpoint:
        print(f"No checkpoint written: ran {MAX_EPOCH} epochs, "
              f"but checkpoints are only saved after epoch {SAVE_AFTER_EPOCH}.")