Fixed NumPy deprecation warning in efr.py #1271

Merged
merged 2 commits into from
Aug 27, 2024
43 changes: 19 additions & 24 deletions open_spiel/python/algorithms/efr.py
@@ -121,7 +121,7 @@ def __init__(self, game, deviation_gen):
def return_cumulative_regret(self):
"""Returns a dictionary mapping.

- The mapping is fromevery information state to its associated regret
+ The mapping is from every information state to its associated regret
(accumulated over all iterations).
"""
return {
@@ -275,7 +275,7 @@ def _update_current_policy(self, state, current_policy):

state_policy = current_policy.policy_for_key(info_state)
for action, value in self._regret_matching(
-     info_state_node.legal_actions, info_state_node
+     info_state_node
).items():
state_policy[action] = value
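
The call-site simplification above pairs with the new `_regret_matching` signature in the next hunk: the info state node already carries its legal actions, so passing them separately was redundant. A minimal sketch of the new call pattern, using a hypothetical stand-in for the solver's node class:

```python
from dataclasses import dataclass, field

# Hypothetical stand-in for efr.py's info state node; the only attribute the
# new call path needs here is `legal_actions` travelling with the node.
@dataclass
class FakeInfoStateNode:
  legal_actions: list
  y_values: dict = field(default_factory=dict)

node = FakeInfoStateNode(legal_actions=[0, 1, 2])
# Old call: solver._regret_matching(node.legal_actions, node)
# New call: solver._regret_matching(node)  # actions are read from the node
```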

@@ -491,48 +491,48 @@ def __init__(self, game, deviations_name):
deviation_sets = return_behavourial
else:
raise ValueError(
"Unsupported Deviation Set Passed As "
" Constructor Argument"
"Unsupported Deviation Set Passed\
As Constructor Argument"
)
super(EFRSolver, self).__init__(game, deviation_sets)
self._external_only = external_only

- def _regret_matching(self, legal_actions, info_set_node):
+ def _regret_matching(self, info_set_node):
"""Returns an info state policy.

The info state policy returned is the one obtained by applying
regret-matching function over all deviations and time selection functions.

Args:
-   legal_actions: the list of legal actions at this state.
info_set_node: the info state node to compute the policy for.

Returns:
-   A dict of action -> prob for all legal actions.
+   A dict of action -> prob for all legal actions of the
+   info_set_node.
"""
+ legal_actions = info_set_node.legal_actions
+ num_actions = len(legal_actions)
+ info_state_policy = None
z = sum(info_set_node.y_values.values())
- info_state_policy = {}

# The fixed point solution can be directly obtained through the
# weighted regret matrix if only external deviations are used.
if self._external_only and z > 0:
weighted_deviation_matrix = np.zeros(
-     (len(legal_actions), len(legal_actions))
+     (num_actions, num_actions)
)
for dev in list(info_set_node.y_values.keys()):
weighted_deviation_matrix += (
info_set_node.y_values[dev] / z
) * dev.return_transform_matrix()
new_strategy = weighted_deviation_matrix[:, 0]
-   for index in range(len(legal_actions)):
-     info_state_policy[legal_actions[index]] = new_strategy[index]
+   info_state_policy = dict(zip(legal_actions, new_strategy))

# Full regret matching by finding the least squares solution to the
# fixed point of the EFR regret matching function.
# Last row of matrix and the column entry minimises the solution
# towards a strategy.
elif z > 0:
-   num_actions = len(info_set_node.legal_actions)
weighted_deviation_matrix = -np.eye(num_actions)

for dev in list(info_set_node.y_values.keys()):
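
A worked sketch of the external-only branch above may help. Under external deviations, each transform matrix sends any strategy to the same pure strategy, so every column of the weighted matrix is identical, and the fixed point of sigma = M sigma can be read off column 0, which is exactly what `weighted_deviation_matrix[:, 0]` does. The toy matrix construction below is an illustration, not OpenSpiel's `return_transform_matrix()`:

```python
import numpy as np

num_actions = 3

def external_transform(action, num_actions):
  """Toy external-deviation matrix: maps any strategy to pure `action`."""
  m = np.zeros((num_actions, num_actions))
  m[action, :] = 1.0  # every column of m is the unit vector e_action
  return m

# Hypothetical positive cumulative regrets y_d for deviating to each action.
y_values = {0: 2.0, 1: 1.0, 2: 1.0}
z = sum(y_values.values())

weighted = np.zeros((num_actions, num_actions))
for action, y in y_values.items():
  weighted += (y / z) * external_transform(action, num_actions)

strategy = weighted[:, 0]                          # [0.5, 0.25, 0.25]
assert np.allclose(weighted @ strategy, strategy)  # a genuine fixed point
```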
@@ -551,19 +551,17 @@ def _regret_matching(self, legal_actions, info_set_node):
strategy = linalg.lstsq(weighted_deviation_matrix, b)[0]

# Adopt same clipping strategy as paper author's code.
-   strategy[np.where(strategy < 0)] = 0
-   strategy[np.where(strategy > 1)] = 1
+   np.clip(strategy, a_min=0, a_max=1, out=strategy)
+   strategy = strategy / np.sum(strategy)

-   strategy = strategy / sum(strategy)
-   for index in range(len(strategy)):
-     info_state_policy[info_set_node.legal_actions[index]] = strategy[index]
+   info_state_policy = dict(zip(legal_actions, strategy[:,0]))
# Use a uniform strategy as sum of all regrets is negative.
else:
-   for index in range(len(legal_actions)):
-     info_state_policy[legal_actions[index]] = 1.0 / len(legal_actions)
+   unif_policy_value = 1.0 / num_actions
+   info_state_policy = {legal_actions[index]:unif_policy_value
+                        for index in range(num_actions)}
return info_state_policy
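
The changes above are the core of the warning this PR fixes. `linalg.lstsq` returns its solution as a column vector of shape (n, 1); the old `strategy[index]` therefore produced length-1 arrays, and NumPy 1.25+ deprecates implicitly converting those to scalars. Flattening with `strategy[:, 0]` avoids this, and a single in-place `np.clip` replaces the pair of `np.where` assignments. A minimal reproduction, assuming `linalg` here is SciPy's, as in efr.py's imports:

```python
import numpy as np
from scipy import linalg

a = np.eye(3)
b = np.array([[0.2], [0.3], [0.5]])
strategy = linalg.lstsq(a, b)[0]  # column vector, shape (3, 1)

np.clip(strategy, a_min=0, a_max=1, out=strategy)  # one in-place clip
strategy = strategy / np.sum(strategy)

# Old style: strategy[i] is a shape-(1,) array; treating it as a scalar warns.
# New style: strategy[:, 0] is a flat (3,) vector, safe to zip into a policy.
policy = dict(zip([0, 1, 2], strategy[:, 0]))
print(policy)  # {0: 0.2, 1: 0.3, 2: 0.5} (values are NumPy floats)
```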


def _update_average_policy(average_policy, info_state_nodes):
"""Updates in place `average_policy` to the average of all policies iterated.

@@ -617,10 +615,7 @@ def array_to_strat_dict(strategy_array, legal_actions):
Returns:
strategy_dictionary: a dictionary action -> prob value.
"""
- strategy_dictionary = {}
- for action in legal_actions:
-   strategy_dictionary[action] = strategy_array[action]
- return strategy_dictionary
+ return dict(zip(legal_actions, strategy_array))


def create_probs_from_index(indices, current_policy):
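
A short usage sketch of the rewritten helper above, with toy values not taken from the PR. Note that `dict(zip(...))` pairs the i-th probability with the i-th legal action positionally, whereas the removed loop indexed `strategy_array` by the action id itself, which only agrees when the legal actions are exactly 0..n-1:

```python
def array_to_strat_dict(strategy_array, legal_actions):
  """Dictionary action -> prob value, as in the new efr.py helper."""
  return dict(zip(legal_actions, strategy_array))

# Positional pairing works even with non-contiguous action ids.
print(array_to_strat_dict([0.7, 0.3], [2, 5]))  # {2: 0.7, 5: 0.3}
```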