diff --git a/open_spiel/python/algorithms/efr.py b/open_spiel/python/algorithms/efr.py
index 1976d33ed3..e5545d8494 100644
--- a/open_spiel/python/algorithms/efr.py
+++ b/open_spiel/python/algorithms/efr.py
@@ -121,7 +121,7 @@ def __init__(self, game, deviation_gen):
   def return_cumulative_regret(self):
     """Returns a dictionary mapping.
 
-    The mapping is fromevery information state to its associated regret
+    The mapping is from every information state to its associated regret
     (accumulated over all iterations).
     """
     return {
@@ -275,7 +275,7 @@ def _update_current_policy(self, state, current_policy):
         state_policy = current_policy.policy_for_key(info_state)
 
         for action, value in self._regret_matching(
-            info_state_node.legal_actions, info_state_node
+            info_state_node
         ).items():
           state_policy[action] = value
 
@@ -491,48 +491,48 @@ def __init__(self, game, deviations_name):
       deviation_sets = return_behavourial
     else:
       raise ValueError(
-          "Unsupported Deviation Set Passed As "
-          " Constructor Argument"
+          "Unsupported Deviation Set Passed"
+          " As Constructor Argument"
       )
     super(EFRSolver, self).__init__(game, deviation_sets)
     self._external_only = external_only
 
-  def _regret_matching(self, legal_actions, info_set_node):
+  def _regret_matching(self, info_set_node):
     """Returns an info state policy.
 
     The info state policy returned is the one obtained by applying
     regret-matching function over all deviations and time selection functions.
 
     Args:
-      legal_actions: the list of legal actions at this state.
       info_set_node: the info state node to compute the policy for.
 
     Returns:
-      A dict of action -> prob for all legal actions.
+      A dict of action -> prob for all legal actions of the
+      info_set_node.
     """
+    legal_actions = info_set_node.legal_actions
+    num_actions = len(legal_actions)
+    info_state_policy = None
     z = sum(info_set_node.y_values.values())
-    info_state_policy = {}
 
     # The fixed point solution can be directly obtained through the
     # weighted regret matrix if only external deviations are used.
     if self._external_only and z > 0:
       weighted_deviation_matrix = np.zeros(
-          (len(legal_actions), len(legal_actions))
+          (num_actions, num_actions)
       )
       for dev in list(info_set_node.y_values.keys()):
         weighted_deviation_matrix += (
             info_set_node.y_values[dev] / z
         ) * dev.return_transform_matrix()
       new_strategy = weighted_deviation_matrix[:, 0]
-      for index in range(len(legal_actions)):
-        info_state_policy[legal_actions[index]] = new_strategy[index]
+      info_state_policy = dict(zip(legal_actions, new_strategy))
 
     # Full regret matching by finding the least squares solution to the
     # fixed point of the EFR regret matching function.
     # Last row of matrix and the column entry minimises the solution
     # towards a strategy.
     elif z > 0:
-      num_actions = len(info_set_node.legal_actions)
       weighted_deviation_matrix = -np.eye(num_actions)
 
       for dev in list(info_set_node.y_values.keys()):
@@ -551,19 +551,17 @@ def _regret_matching(self, legal_actions, info_set_node):
       strategy = linalg.lstsq(weighted_deviation_matrix, b)[0]
 
       # Adopt same clipping strategy as paper author's code.
-      strategy[np.where(strategy < 0)] = 0
-      strategy[np.where(strategy > 1)] = 1
+      np.clip(strategy, a_min=0, a_max=1, out=strategy)
+      strategy = strategy / np.sum(strategy)
 
-      strategy = strategy / sum(strategy)
-      for index in range(len(strategy)):
-        info_state_policy[info_set_node.legal_actions[index]] = strategy[index]
+      info_state_policy = dict(zip(legal_actions, strategy[:, 0]))
 
     # Use a uniform strategy as sum of all regrets is negative.
     else:
-      for index in range(len(legal_actions)):
-        info_state_policy[legal_actions[index]] = 1.0 / len(legal_actions)
+      unif_policy_value = 1.0 / num_actions
+      info_state_policy = {legal_actions[index]: unif_policy_value
+                           for index in range(num_actions)}
     return info_state_policy
 
-
 def _update_average_policy(average_policy, info_state_nodes):
   """Updates in place `average_policy` to the average of all policies
   iterated.
@@ -617,10 +615,7 @@ def array_to_strat_dict(strategy_array, legal_actions):
   Returns:
     strategy_dictionary: a dictionary action -> prob value.
   """
-  strategy_dictionary = {}
-  for action in legal_actions:
-    strategy_dictionary[action] = strategy_array[action]
-  return strategy_dictionary
+  return dict(zip(legal_actions, strategy_array))
 
 
 def create_probs_from_index(indices, current_policy):
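
Reviewer note on the external-only fast path: an external deviation replaces whatever the player would have done with one fixed strategy, so its transform matrix has identical columns, and any convex combination of such matrices maps every probability vector to that shared column. That is why new_strategy = weighted_deviation_matrix[:, 0] is already the fixed point. A minimal sketch of this, where ExternalDeviation is a hypothetical stand-in for the deviation objects in efr.py that expose return_transform_matrix():

# Sketch: why column 0 is the fixed point when only external deviations
# are present. ExternalDeviation is a hypothetical stand-in, not the
# efr.py deviation class.
import numpy as np


class ExternalDeviation:
  """Maps every strategy to one fixed strategy."""

  def __init__(self, fixed_strategy):
    self._s = np.asarray(fixed_strategy, dtype=float)

  def return_transform_matrix(self):
    # Every column equals the fixed strategy, so A @ x == s for any
    # probability vector x.
    return np.tile(self._s[:, None], (1, len(self._s)))


# Two external deviations with cumulative y-weights 2.0 and 6.0.
y_values = {ExternalDeviation([1.0, 0.0]): 2.0,
            ExternalDeviation([0.25, 0.75]): 6.0}
z = sum(y_values.values())

num_actions = 2
weighted_deviation_matrix = np.zeros((num_actions, num_actions))
for dev, y in y_values.items():
  weighted_deviation_matrix += (y / z) * dev.return_transform_matrix()

new_strategy = weighted_deviation_matrix[:, 0]
# Fixed-point check: the weighted matrix maps the strategy to itself.
assert np.allclose(weighted_deviation_matrix @ new_strategy, new_strategy)
print(new_strategy)  # [0.4375 0.5625]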
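Reviewer note on the least-squares branch: the lines between the two hunks (which build b and append the normalisation row) are not shown in this diff, so the sketch below reconstructs the solve under that assumption: stack a row of ones under (sum_d (y_d / z) * A_d) - I, set the matching entry of b to 1 so the least-squares solution is pulled towards a normalised strategy (the "last row ... minimises the solution towards a strategy" comment), then clip and renormalise exactly as the patch now does with np.clip and np.sum. The helper name lstsq_fixed_point is illustrative, not from efr.py:

# Sketch (assumed reconstruction, not part of the patch) of the full
# regret-matching solve that the second hunk's clipping change operates on.
import numpy as np
from scipy import linalg


def lstsq_fixed_point(weighted_deviation_matrix, num_actions):
  """Finds x with (sum_d (y_d / z) * A_d - I) x ~ 0 and sum(x) ~ 1."""
  # Append a row of ones; the matching 1 in b pulls the least-squares
  # solution towards a properly normalised strategy.
  lhs = np.vstack([weighted_deviation_matrix, np.ones((1, num_actions))])
  b = np.zeros((num_actions + 1, 1))
  b[num_actions] = 1
  strategy = linalg.lstsq(lhs, b)[0]
  # Same post-processing as the patch: clip into [0, 1] in place, then
  # renormalise so the entries sum to one.
  np.clip(strategy, a_min=0, a_max=1, out=strategy)
  strategy = strategy / np.sum(strategy)
  return strategy[:, 0]  # lstsq returns a column; flatten for dict(zip(...)).


# Two actions, one column-stochastic deviation matrix with full weight.
m = -np.eye(2) + np.array([[0.5, 0.5], [0.5, 0.5]])
print(lstsq_fixed_point(m, 2))  # [0.5 0.5]

Because linalg.lstsq returns a column vector here, the patched line info_state_policy = dict(zip(legal_actions, strategy[:, 0])) needs the [:, 0] to pair each legal action with a scalar probability rather than a length-1 array.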