-
Notifications
You must be signed in to change notification settings - Fork 83
/
train.py
180 lines (146 loc) · 6 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
from user_simulator import UserSimulator
from error_model_controller import ErrorModelController
from dqn_agent import DQNAgent
from state_tracker import StateTracker
import pickle, argparse, json, math
from utils import remove_empty_slots
from user import User
if __name__ == "__main__":
    # Entry point: load constants, load the movie DB/dict/goals, build the
    # dialogue components (user, error model controller, state tracker, agent).
    # Training itself is kicked off by warmup_run()/train_run() at file end.
    #
    # Can provide constants file path in args OR run it as is and change 'CONSTANTS_FILE_PATH' below
    # 1) In terminal: python train.py --constants_path "constants.json"
    # 2) Run this file as is
    parser = argparse.ArgumentParser()
    parser.add_argument('--constants_path', dest='constants_path', type=str, default='')
    args = parser.parse_args()
    params = vars(args)

    # Load constants json into dict (the CLI path, when given, wins over the default)
    CONSTANTS_FILE_PATH = 'constants.json'
    if len(params['constants_path']) > 0:
        constants_file = params['constants_path']
    else:
        constants_file = CONSTANTS_FILE_PATH
    with open(constants_file) as f:
        constants = json.load(f)

    # Load file path constants
    file_path_dict = constants['db_file_paths']
    DATABASE_FILE_PATH = file_path_dict['database']
    DICT_FILE_PATH = file_path_dict['dict']
    USER_GOALS_FILE_PATH = file_path_dict['user_goals']

    # Load run constants
    run_dict = constants['run']
    USE_USERSIM = run_dict['usersim']  # True -> simulated user; False -> interactive User
    WARMUP_MEM = run_dict['warmup_mem']  # experiences to gather before training (see warmup_run)
    NUM_EP_TRAIN = run_dict['num_ep_run']  # total training episodes (see train_run)
    TRAIN_FREQ = run_dict['train_freq']  # evaluate/train every TRAIN_FREQ episodes
    MAX_ROUND_NUM = run_dict['max_round_num']  # not referenced below; presumably read by components via `constants` — TODO confirm
    SUCCESS_RATE_THRESHOLD = run_dict['success_rate_threshold']  # memory-flush threshold (see train_run)

    # Load movie DB
    # Note: If you get an unpickling error here then run 'pickle_converter.py' and it should fix it
    # NOTE(review): pickle.load executes arbitrary code on load — only use trusted local files
    database = pickle.load(open(DATABASE_FILE_PATH, 'rb'), encoding='latin1')
    # Clean DB
    remove_empty_slots(database)
    # Load movie dict
    db_dict = pickle.load(open(DICT_FILE_PATH, 'rb'), encoding='latin1')
    # Load goal File
    user_goals = pickle.load(open(USER_GOALS_FILE_PATH, 'rb'), encoding='latin1')

    # Init. Objects — these module-level names are closed over by
    # run_round/warmup_run/train_run/episode_reset below
    if USE_USERSIM:
        user = UserSimulator(user_goals, constants, database)
    else:
        user = User(constants)
    emc = ErrorModelController(db_dict, constants)
    state_tracker = StateTracker(database, constants)
    dqn_agent = DQNAgent(state_tracker.get_state_size(), constants)
def run_round(state, warmup=False):
    """Play one round (agent turn + user turn) of the current episode.

    Parameters:
        state: the state tracker's current dialogue representation
            (whatever state_tracker.get_state() returns — presumably a
            numeric vector sized get_state_size(); TODO confirm)
        warmup (bool): if True the agent acts with its rule-based policy
            (passed through as use_rule to dqn_agent.get_action)

    Returns:
        (next_state, reward, done, success) — reward/done/success come from
        user.step; next_state from state_tracker.get_state(done).

    Side effects: updates the state tracker with both actions and appends one
    experience tuple to the agent's replay memory.
    """
    # NOTE(review): indentation was reconstructed from a de-indented source;
    # assumes only the error infusion is conditional on `not done` — verify upstream.
    # 1) Agent takes action given state tracker's representation of dialogue (state)
    agent_action_index, agent_action = dqn_agent.get_action(state, use_rule=warmup)
    # 2) Update state tracker with the agent's action
    state_tracker.update_state_agent(agent_action)
    # 3) User takes action given agent action
    user_action, reward, done, success = user.step(agent_action)
    if not done:
        # 4) Infuse error into semantic frame level of user action
        # (skipped on the terminal action so the episode-ending signal is untouched)
        emc.infuse_error(user_action)
    # 5) Update state tracker with user action
    state_tracker.update_state_user(user_action)
    # 6) Get next state and add experience
    next_state = state_tracker.get_state(done)
    dqn_agent.add_experience(state, agent_action_index, reward, next_state, done)
    return next_state, reward, done, success
def warmup_run():
    """Run the warmup stage of training, which fills the agent's memory.

    The agent uses its rule-based policy (run_round with warmup=True) and its
    replay memory is filled as this runs. The loop terminates once at least
    WARMUP_MEM total steps have been taken or the memory buffer is full.

    Side effects: resets/advances episodes via episode_reset and populates
    dqn_agent's replay memory.
    """
    print('Warmup Started...')
    total_step = 0
    # BUGFIX: was `total_step != WARMUP_MEM`. The count is only checked between
    # episodes and each episode adds a variable number of steps, so the counter
    # can jump past WARMUP_MEM without ever equalling it — the `!=` test would
    # then keep looping until the memory buffer filled. `<` terminates as soon
    # as at least WARMUP_MEM steps have been gathered.
    while total_step < WARMUP_MEM and not dqn_agent.is_memory_full():
        # Reset episode
        episode_reset()
        done = False
        # Get initial state from state tracker
        state = state_tracker.get_state()
        while not done:
            next_state, _, done, _ = run_round(state, warmup=True)
            total_step += 1
            state = next_state
    print('...Warmup Ended')
def train_run():
    """Run the loop that trains the agent on the goal-oriented chatbot task.

    Plays NUM_EP_TRAIN episodes. At the end of every TRAIN_FREQ-episode period
    the success rate is evaluated: the replay memory is flushed when the rate
    ties/beats the best so far and meets SUCCESS_RATE_THRESHOLD, a new best
    rate saves the network weights, and the agent's target network is synced
    before training on the replay memory.
    """
    print('Training Started...')
    episode = 0
    reward_accum = 0
    success_accum = 0
    best_rate = 0.0
    while episode < NUM_EP_TRAIN:
        episode_reset()
        episode += 1
        done = False
        state = state_tracker.get_state()
        # Play one full episode, accumulating reward per round
        while not done:
            state, reward, done, success = run_round(state)
            reward_accum += reward
        # One success flag per episode (value from the final round)
        success_accum += success

        # Only evaluate/train at period boundaries
        if episode % TRAIN_FREQ != 0:
            continue
        success_rate = success_accum / TRAIN_FREQ
        avg_reward = reward_accum / TRAIN_FREQ
        # Flush memory once the policy is strong enough that stale, weaker
        # experiences would hold training back
        if success_rate >= best_rate and success_rate >= SUCCESS_RATE_THRESHOLD:
            dqn_agent.empty_memory()
        # Track the best success rate and keep its weights
        if success_rate > best_rate:
            print('Episode: {} NEW BEST SUCCESS RATE: {} Avg Reward: {}' .format(episode, success_rate, avg_reward))
            best_rate = success_rate
            dqn_agent.save_weights()
        success_accum = 0
        reward_accum = 0
        # Sync target network, then train on the replay memory
        dqn_agent.copy()
        dqn_agent.train()
    print('...Training Ended')
def episode_reset():
    """Reset the episode/conversation in the warmup and training loops.

    Resets the state tracker, user and agent, and primes the state tracker
    with the user's (error-infused) initial action. The call order below is
    significant: the tracker must be reset before it records the opening user
    action, and the user must produce that action before the tracker update.

    Side effects only; returns None.
    """
    # First reset the state tracker
    state_tracker.reset()
    # Then pick an init user action
    user_action = user.reset()
    # Infuse with error
    emc.infuse_error(user_action)
    # And update state tracker
    state_tracker.update_state_user(user_action)
    # Finally, reset agent
    dqn_agent.reset()
# Fill the replay memory with rule-based experience, then train the DQN agent.
warmup_run()
train_run()