
Fix DQN bug (wrong target Q-value for illegal actions) #1259

Merged
merged 1 commit into from
Aug 5, 2024

Conversation

nathanlct
Contributor

I came across a bug in the PyTorch DQN implementation in open_spiel/python/pytorch/dqn.py.

TL;DR: I replaced sys.float_info.min with torch.finfo(torch.float).min, which is the minimum value of a float32.

This code computes the max Q-target, setting illegal actions' Q-values to a large negative value so that they cannot be considered in the max:

illegal_actions_mask = 1 - legal_actions_mask
legal_target_q_values = self._target_q_values.masked_fill(
    illegal_actions_mask.bool(), ILLEGAL_ACTION_LOGITS_PENALTY
)
max_next_q = torch.max(legal_target_q_values, dim=1)[0]

However, ILLEGAL_ACTION_LOGITS_PENALTY is set to sys.float_info.min, which (surprisingly) is a tiny positive number very close to 0 (see https://docs.python.org/3/library/sys.html#sys.float_info.min).

Python 3.8.19
>>> import sys
>>> sys.float_info.min
2.2250738585072014e-308
>>> import torch
>>> torch.finfo(torch.float).min
-3.4028234663852886e+38
>>> torch.finfo(torch.float32).min
-3.4028234663852886e+38

The TensorFlow DQN implementation in open_spiel/python/algorithms/dqn.py is correct, though: ILLEGAL_ACTION_LOGITS_PENALTY = -1e9
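To make the failure mode concrete, here is a minimal pure-Python sketch of the masking step (no torch required; the Q-values below are made up for illustration). When all legal Q-values happen to be negative, the tiny positive fill value beats them all, so the illegal action "wins" the max:

```python
import sys

# Hypothetical target Q-values for one state; action 2 is illegal.
q_values = [-0.5, -1.2, 0.8]
legal_mask = [1, 1, 0]

def max_legal_q(q, legal, penalty):
    # Mimics masked_fill: illegal actions get the penalty value before the max.
    return max(qi if li else penalty for qi, li in zip(q, legal))

buggy = max_legal_q(q_values, legal_mask, sys.float_info.min)       # ~2.2e-308 (positive!)
fixed = max_legal_q(q_values, legal_mask, -3.4028234663852886e+38)  # float32 lowest

print(buggy)  # 2.2250738585072014e-308 -- the illegal action wins the max
print(fixed)  # -0.5 -- the best *legal* action, as intended
```

With the buggy penalty, the Q-target is silently pulled toward 0 whenever the best legal Q-value is negative, which matches the exploitability gap reported below.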

I ran a DQN best response against a Phantom Tic-Tac-Toe PPO policy and got a pretty significant difference of 0.1 in exploitability (consistent across several seeds):

[Figure: exploitability over training for the three DQN variants]

blue: tensorflow DQN, orange: torch DQN before fix, green: torch DQN after fix

Fix DQN bug: set ILLEGAL_ACTION_LOGITS_PENALTY to a large negative number instead of 0.

google-cla bot commented Aug 2, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).


@lanctot
Collaborator

lanctot commented Aug 2, 2024

Wow, amazing find. Good catch! Thanks, especially for the thorough evidence!

It's really important to get these contributions from the community for stuff like this because we don't use the PyTorch implementations for our research.

Could I ask you to take a quick look at the DQN C++ implementation? (Also based on LibTorch: https://github.com/google-deepmind/open_spiel/tree/master/open_spiel/algorithms/dqn_torch) just to be sure that it doesn't also have the same issue?

@lanctot
Collaborator

lanctot commented Aug 2, 2024

Nevermind, just checked! We're good in the C++ DQN:

std::numeric_limits<float>::lowest();

and

https://en.cppreference.com/w/cpp/types/numeric_limits/lowest
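For reference, the same min-vs-lowest distinction exists in Python's stdlib, which is likely the source of the original confusion: sys.float_info.min is the analogue of C++ std::numeric_limits<double>::min() (the smallest positive normalized value), while the analogue of lowest() is -sys.float_info.max. A quick sketch:

```python
import sys

# sys.float_info.min mirrors std::numeric_limits<double>::min():
# the smallest *positive* normalized double, not the most negative float.
print(sys.float_info.min)  # 2.2250738585072014e-308
assert sys.float_info.min > 0

# The most negative finite double (the analogue of lowest()) is -max.
lowest = -sys.float_info.max
print(lowest)  # -1.7976931348623157e+308
```

So any code porting a "large negative sentinel" between C++ and Python needs lowest()/-max (or torch.finfo(...).min), never min().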

@lanctot added labels on Aug 5, 2024:
- imported: This PR has been imported and is awaiting internal review. Please avoid any more local changes, thanks!
- merged internally: The code is now submitted to our internal repo and will be merged in the next github sync.
@lanctot lanctot merged commit 7ff9e28 into google-deepmind:master Aug 5, 2024
10 checks passed