-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
92 lines (73 loc) · 3.36 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import argparse
import importlib
import torch
import utils
if __name__ == '__main__':
    exam_code = '''
e.g)
python train.py --dt pf --model pf_bicubic
'''
    parser = argparse.ArgumentParser("Train Mask R-CNN model", epilog=exam_code)
    # setting
    parser.add_argument('-d',   '--dt',       default='pf',      metavar='{pf,bln}', help='Dataset')
    parser.add_argument('-m',   '--model',    default='bicubic', metavar='{...}',    help='model class name')
    parser.add_argument('-pre', '--pretrain', default='none',    metavar='{...}',    help='pretrained model file path')
    parser.add_argument('-o',   '--out',      default='./model/new.pth',             help='model path to save')
    # hyper param
    parser.add_argument('-epo', '--epochs',  default=50,       type=int, help='number of epochs')
    parser.add_argument('-b',   '--batch',   default=4,        type=int, help='batch size')
    parser.add_argument('-w',   '--workers', default=4,        type=int, help='number of batch workers')
    parser.add_argument('-de',  '--device',  default='cuda:0',           help='device e.g} cuda:0')
    args = parser.parse_args()

    # args edit: normalize dataset aliases to canonical names and pick the
    # class count the model head needs.
    args.dt = args.dt.lower()
    if args.dt in ('pf', 'pennfudan'):
        args.dt = 'PennFudan'
        num_classes = 2
    elif args.dt == 'bln':
        args.dt = 'balloon'
        num_classes = 2
    else:
        # BUG FIX: an unknown dataset previously left num_classes undefined,
        # producing a NameError at get_models(); fail fast with a clear message.
        parser.error(f"unknown dataset '{args.dt}' (expected one of: pf, pennfudan, bln)")

    # Derive a descriptive output path when the user kept the default.
    if args.out == './model/new.pth':
        args.out = f'./model/{args.dt}_{args.model}.pth'
    import pprint
    pprint.pprint(args)

    ###### changeable main
    from datasets import get_dataset
    from models import get_models
    dataset, dataset_test = get_dataset(args.dt)
    model = get_models(name=args.model, num_classes=num_classes)
    ######

    # BUG FIX: --pretrain was accepted but never used; load the checkpoint
    # when a path is supplied. Accepts either this script's save format
    # ({'state_dict': ...}) or a raw state_dict.
    if args.pretrain != 'none':
        checkpoint = torch.load(args.pretrain, map_location='cpu')
        model.load_state_dict(checkpoint.get('state_dict', checkpoint))

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch, shuffle=True, num_workers=args.workers,
        collate_fn=utils.collate_fn)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, shuffle=False, num_workers=args.workers,
        collate_fn=utils.collate_fn)

    # Use the requested CUDA device only when CUDA is actually available.
    device = args.device if torch.cuda.is_available() else 'cpu'
    model.to(device)

    # Construct an optimizer over trainable parameters only (frozen layers excluded).
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005,
                                momentum=0.9, weight_decay=0.0005)
    # Learning rate decreases 10x every 3 epochs.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=3,
                                                   gamma=0.1)

    from engine import train_one_epoch, evaluate
    evaluators = []
    for epoch in range(args.epochs):
        # train for one epoch, printing every 10 iterations
        train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
        # update the learning rate
        lr_scheduler.step()
        # evaluate on the test dataset after every epoch
        evaluators.append(evaluate(model, data_loader_test, device=device))

    # BUG FIX: torch.save fails if the target directory is missing; create it.
    import os
    os.makedirs(os.path.dirname(args.out) or '.', exist_ok=True)
    torch.save({'state_dict': model.state_dict(),
                'evaluators': evaluators
                }, args.out)