-
Notifications
You must be signed in to change notification settings - Fork 83
/
test.py
116 lines (95 loc) · 3.87 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from user_simulator import UserSimulator
from error_model_controller import ErrorModelController
from dqn_agent import DQNAgent
from state_tracker import StateTracker
import pickle, argparse, json
from user import User
from utils import remove_empty_slots
if __name__ == "__main__":
    # Script setup: read the constants JSON, load the pickled data files and
    # build the module-level objects (user, emc, state_tracker, dqn_agent)
    # that test_run() and episode_reset() use.
    #
    # The constants file path can be provided in args, OR run the file as is
    # and change 'CONSTANTS_FILE_PATH' below:
    # 1) In terminal: python test.py --constants_path "constants.json"
    # 2) Run this file as is
    parser = argparse.ArgumentParser()
    parser.add_argument('--constants_path', dest='constants_path', type=str, default='')
    args = parser.parse_args()
    params = vars(args)

    # Load constants json into dict; an empty --constants_path falls back to
    # the default path.
    CONSTANTS_FILE_PATH = 'constants.json'
    constants_file = params['constants_path'] or CONSTANTS_FILE_PATH

    with open(constants_file) as f:
        constants = json.load(f)

    # Load file path constants
    file_path_dict = constants['db_file_paths']
    DATABASE_FILE_PATH = file_path_dict['database']
    DICT_FILE_PATH = file_path_dict['dict']
    USER_GOALS_FILE_PATH = file_path_dict['user_goals']

    # Load run constants
    run_dict = constants['run']
    USE_USERSIM = run_dict['usersim']
    NUM_EP_TEST = run_dict['num_ep_run']
    MAX_ROUND_NUM = run_dict['max_round_num']

    # NOTE: pickle can execute arbitrary code while loading, so the files named
    # in the constants file must come from a trusted source.
    # Each file is opened in a 'with' block so its handle is closed as soon as
    # the data has been read.

    # Load movie DB
    # Note: If you get an unpickling error here then run 'pickle_converter.py' and it should fix it
    with open(DATABASE_FILE_PATH, 'rb') as f:
        database = pickle.load(f, encoding='latin1')

    # Clean DB
    remove_empty_slots(database)

    # Load movie dict
    with open(DICT_FILE_PATH, 'rb') as f:
        db_dict = pickle.load(f, encoding='latin1')

    # Load goal file
    with open(USER_GOALS_FILE_PATH, 'rb') as f:
        user_goals = pickle.load(f, encoding='latin1')

    # Init. Objects: a simulated user when 'usersim' is set, otherwise the
    # User class.
    if USE_USERSIM:
        user = UserSimulator(user_goals, constants, database)
    else:
        user = User(constants)
    emc = ErrorModelController(db_dict, constants)
    state_tracker = StateTracker(database, constants)
    dqn_agent = DQNAgent(state_tracker.get_state_size(), constants)
def test_run():
    """
    Evaluate the agent for NUM_EP_TEST episodes.

    Each episode is started with episode_reset() and then alternates between
    the agent and the user until the user reports the conversation as done.
    One line with the episode number, success flag and accumulated reward is
    printed per episode. Uses the module-level user, emc, state_tracker and
    dqn_agent objects.
    """

    print('Testing Started...')
    episode = 0
    while episode < NUM_EP_TEST:
        episode_reset()
        episode += 1
        ep_reward = 0
        # The state tracker supplies the first state of the conversation
        state = state_tracker.get_state()
        while True:
            # Ask the agent for its action; the action index is not needed here
            _, agent_action = dqn_agent.get_action(state)
            # Record the agent's action in the state tracker
            state_tracker.update_state_agent(agent_action)
            # Let the user respond to the agent's action
            user_action, reward, done, success = user.step(agent_action)
            ep_reward += reward
            # Error is only infused into user actions that do not end the episode
            if not done:
                emc.infuse_error(user_action)
            # The tracker is updated with the user action in either case
            state_tracker.update_state_user(user_action)
            # The resulting state becomes the input of the next turn
            state = state_tracker.get_state(done)
            if done:
                break
        print('Episode: {} Success: {} Reward: {}'.format(episode, success, ep_reward))
    print('...Testing Ended')
def episode_reset():
    """
    Prepare a fresh conversation for the testing loop.

    Resets the state tracker, obtains the user's opening action, passes it
    through the error model, feeds it to the state tracker and finally resets
    the agent. Operates on the module-level state_tracker, user, emc and
    dqn_agent objects.
    """

    # The state tracker is cleared before anything else
    state_tracker.reset()
    # Resetting the user yields the opening user action
    initial_user_action = user.reset()
    # The opening action goes through the error model
    emc.infuse_error(initial_user_action)
    # ...and is then recorded by the state tracker
    state_tracker.update_state_user(initial_user_action)
    # The agent is reset last
    dqn_agent.reset()
# Only start the evaluation when this file is executed as a script. The objects
# test_run() depends on are created under the same guard, so a plain import of
# this module must not trigger the run.
if __name__ == "__main__":
    test_run()