-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
115 lines (97 loc) · 3.76 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import argparse
import paddle
import paddle.nn as nn
from paddle.io import DataLoader
from paddle.metric import Accuracy
from paddle.optimizer import Adam
from paddle.optimizer.lr import StepDecay
from data import ModelNetDataset
from model import CrossEntropyMatrixRegularization, PointNetClassifier
def parse_args():
    """Parse command-line options for PointNet training on ModelNet10/40.

    Returns:
        argparse.Namespace with batch size, learning-rate schedule,
        dataset location, and checkpointing options.
    """
    parser = argparse.ArgumentParser("Train")
    parser.add_argument(
        "--batch_size", type=int, default=32, help="batch size in training"
    )
    parser.add_argument("--num_category", type=int, default=40, help="ModelNet10/40")
    parser.add_argument(
        "--learning_rate", type=float, default=1e-3, help="learning rate in training"
    )
    parser.add_argument("--num_point", type=int, default=1024, help="point number")
    parser.add_argument("--max_epochs", type=int, default=200, help="max epochs")
    # fixed typo: "num wrokers" -> "num workers"
    parser.add_argument("--num_workers", type=int, default=32, help="num workers")
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay")
    parser.add_argument("--log_batch_num", type=int, default=50)
    parser.add_argument("--model_path", type=str, default="pointnet.pdparams")
    parser.add_argument("--lr_decay_step", type=int, default=20)
    parser.add_argument("--lr_decay_gamma", type=float, default=0.7)
    parser.add_argument(
        "--data_dir", type=str, default="modelnet40_normal_resampled",
    )
    return parser.parse_args()
def train(args):
    """Train a PointNet classifier on ModelNet and keep the best checkpoint.

    Runs one evaluation pass over the test split after every training
    epoch; the model's state dict is saved to ``args.model_path`` whenever
    test accuracy improves on the best seen so far.

    Args:
        args: argparse.Namespace produced by ``parse_args``.
    """
    train_data = ModelNetDataset(args.data_dir, split="train", num_point=args.num_point)
    test_data = ModelNetDataset(args.data_dir, split="test", num_point=args.num_point)
    train_loader = DataLoader(
        train_data,
        shuffle=True,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
    )
    test_loader = DataLoader(
        test_data,
        shuffle=False,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
    )
    model = PointNetClassifier()
    scheduler = StepDecay(
        learning_rate=args.learning_rate,
        step_size=args.lr_decay_step,
        gamma=args.lr_decay_gamma,
    )
    optimizer = Adam(
        learning_rate=scheduler,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
    )
    # Cross-entropy plus (presumably) the feature-transform regularizer on
    # trans_feat, as in the PointNet paper -- see model.py for details.
    loss_fn = CrossEntropyMatrixRegularization()
    metrics = Accuracy()
    best_test_acc = 0
    for epoch in range(args.max_epochs):
        metrics.reset()
        model.train()
        for batch_id, data in enumerate(train_loader):
            x, y = data
            pred, trans_input, trans_feat = model(x)
            loss = loss_fn(pred, y, trans_feat)
            correct = metrics.compute(pred, y)
            metrics.update(correct)
            loss.backward()
            if (batch_id + 1) % args.log_batch_num == 0:
                print(
                    "Epoch: {}, Batch ID: {}, Loss: {}, ACC: {}".format(
                        epoch, batch_id + 1, loss.item(), metrics.accumulate()
                    )
                )
            optimizer.step()
            optimizer.clear_grad()
        # StepDecay is epoch-based: step once per epoch, not per batch.
        scheduler.step()
        metrics.reset()
        model.eval()
        # Evaluation needs no gradients; no_grad avoids building the
        # autograd graph for every test batch.
        with paddle.no_grad():
            for batch_id, data in enumerate(test_loader):
                x, y = data
                pred, trans_input, trans_feat = model(x)
                correct = metrics.compute(pred, y)
                metrics.update(correct)
        test_acc = metrics.accumulate()
        print("Test epoch: {}, acc is: {}".format(epoch, test_acc))
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            paddle.save(model.state_dict(), args.model_path)
            print("Model saved. Best Test ACC: {}".format(test_acc))
        else:
            print("Model not saved. Current Best Test ACC: {}".format(best_test_acc))
if __name__ == "__main__":
args = parse_args()
print(args)
train(args)