-
Notifications
You must be signed in to change notification settings - Fork 83
/
state_tracker.py
193 lines (150 loc) · 8.45 KB
/
state_tracker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
from db_query import DBQuery
import numpy as np
from utils import convert_list_to_dict
from dialogue_config import all_intents, all_slots, usersim_default_key
import copy
class StateTracker:
"""Tracks the state of the episode/conversation and prepares the state representation for the agent."""
def __init__(self, database, constants):
"""
The constructor of StateTracker.
The constructor of StateTracker which creates a DB query object, creates necessary state rep. dicts, etc. and
calls reset.
Parameters:
database (dict): The database with format dict(long: dict)
constants (dict): Loaded constants in dict
"""
self.db_helper = DBQuery(database)
self.match_key = usersim_default_key
self.intents_dict = convert_list_to_dict(all_intents)
self.num_intents = len(all_intents)
self.slots_dict = convert_list_to_dict(all_slots)
self.num_slots = len(all_slots)
self.max_round_num = constants['run']['max_round_num']
self.none_state = np.zeros(self.get_state_size())
self.reset()
def get_state_size(self):
"""Returns the state size of the state representation used by the agent."""
return 2 * self.num_intents + 7 * self.num_slots + 3 + self.max_round_num
def reset(self):
"""Resets current_informs, history and round_num."""
self.current_informs = {}
# A list of the dialogues (dicts) by the agent and user so far in the conversation
self.history = []
self.round_num = 0
def print_history(self):
"""Helper function if you want to see the current history action by action."""
for action in self.history:
print(action)
def get_state(self, done=False):
"""
Returns the state representation as a numpy array which is fed into the agent's neural network.
The state representation contains useful information for the agent about the current state of the conversation.
Processes by the agent to be fed into the neural network. Ripe for experimentation and optimization.
Parameters:
done (bool): Indicates whether this is the last dialogue in the episode/conversation. Default: False
Returns:
numpy.array: A numpy array of shape (state size,)
"""
# If done then fill state with zeros
if done:
return self.none_state
user_action = self.history[-1]
db_results_dict = self.db_helper.get_db_results_for_slots(self.current_informs)
last_agent_action = self.history[-2] if len(self.history) > 1 else None
# Create one-hot of intents to represent the current user action
user_act_rep = np.zeros((self.num_intents,))
user_act_rep[self.intents_dict[user_action['intent']]] = 1.0
# Create bag of inform slots representation to represent the current user action
user_inform_slots_rep = np.zeros((self.num_slots,))
for key in user_action['inform_slots'].keys():
user_inform_slots_rep[self.slots_dict[key]] = 1.0
# Create bag of request slots representation to represent the current user action
user_request_slots_rep = np.zeros((self.num_slots,))
for key in user_action['request_slots'].keys():
user_request_slots_rep[self.slots_dict[key]] = 1.0
# Create bag of filled_in slots based on the current_slots
current_slots_rep = np.zeros((self.num_slots,))
for key in self.current_informs:
current_slots_rep[self.slots_dict[key]] = 1.0
# Encode last agent intent
agent_act_rep = np.zeros((self.num_intents,))
if last_agent_action:
agent_act_rep[self.intents_dict[last_agent_action['intent']]] = 1.0
# Encode last agent inform slots
agent_inform_slots_rep = np.zeros((self.num_slots,))
if last_agent_action:
for key in last_agent_action['inform_slots'].keys():
agent_inform_slots_rep[self.slots_dict[key]] = 1.0
# Encode last agent request slots
agent_request_slots_rep = np.zeros((self.num_slots,))
if last_agent_action:
for key in last_agent_action['request_slots'].keys():
agent_request_slots_rep[self.slots_dict[key]] = 1.0
# Value representation of the round num
turn_rep = np.zeros((1,)) + self.round_num / 5.
# One-hot representation of the round num
turn_onehot_rep = np.zeros((self.max_round_num,))
turn_onehot_rep[self.round_num - 1] = 1.0
# Representation of DB query results (scaled counts)
kb_count_rep = np.zeros((self.num_slots + 1,)) + db_results_dict['matching_all_constraints'] / 100.
for key in db_results_dict.keys():
if key in self.slots_dict:
kb_count_rep[self.slots_dict[key]] = db_results_dict[key] / 100.
# Representation of DB query results (binary)
kb_binary_rep = np.zeros((self.num_slots + 1,)) + np.sum(db_results_dict['matching_all_constraints'] > 0.)
for key in db_results_dict.keys():
if key in self.slots_dict:
kb_binary_rep[self.slots_dict[key]] = np.sum(db_results_dict[key] > 0.)
state_representation = np.hstack(
[user_act_rep, user_inform_slots_rep, user_request_slots_rep, agent_act_rep, agent_inform_slots_rep,
agent_request_slots_rep, current_slots_rep, turn_rep, turn_onehot_rep, kb_binary_rep,
kb_count_rep]).flatten()
return state_representation
def update_state_agent(self, agent_action):
"""
Updates the dialogue history with the agent's action and augments the agent's action.
Takes an agent action and updates the history. Also augments the agent_action param with query information and
any other necessary information.
Parameters:
agent_action (dict): The agent action of format dict('intent': string, 'inform_slots': dict,
'request_slots': dict) and changed to dict('intent': '', 'inform_slots': {},
'request_slots': {}, 'round': int, 'speaker': 'Agent')
"""
if agent_action['intent'] == 'inform':
assert agent_action['inform_slots']
inform_slots = self.db_helper.fill_inform_slot(agent_action['inform_slots'], self.current_informs)
agent_action['inform_slots'] = inform_slots
assert agent_action['inform_slots']
key, value = list(agent_action['inform_slots'].items())[0] # Only one
assert key != 'match_found'
assert value != 'PLACEHOLDER', 'KEY: {}'.format(key)
self.current_informs[key] = value
# If intent is match_found then fill the action informs with the matches informs (if there is a match)
elif agent_action['intent'] == 'match_found':
assert not agent_action['inform_slots'], 'Cannot inform and have intent of match found!'
db_results = self.db_helper.get_db_results(self.current_informs)
if db_results:
# Arbitrarily pick the first value of the dict
key, value = list(db_results.items())[0]
agent_action['inform_slots'] = copy.deepcopy(value)
agent_action['inform_slots'][self.match_key] = str(key)
else:
agent_action['inform_slots'][self.match_key] = 'no match available'
self.current_informs[self.match_key] = agent_action['inform_slots'][self.match_key]
agent_action.update({'round': self.round_num, 'speaker': 'Agent'})
self.history.append(agent_action)
def update_state_user(self, user_action):
"""
Updates the dialogue history with the user's action and augments the user's action.
Takes a user action and updates the history. Also augments the user_action param with necessary information.
Parameters:
user_action (dict): The user action of format dict('intent': string, 'inform_slots': dict,
'request_slots': dict) and changed to dict('intent': '', 'inform_slots': {},
'request_slots': {}, 'round': int, 'speaker': 'User')
"""
for key, value in user_action['inform_slots'].items():
self.current_informs[key] = value
user_action.update({'round': self.round_num, 'speaker': 'User'})
self.history.append(user_action)
self.round_num += 1