Skip to content

Commit

Permalink
Fix typo in IS-MCTS test and add default support for ResampleFromInfo…
Browse files Browse the repository at this point in the history
…state for perfect information game (defaults to State::Clone()).

PiperOrigin-RevId: 657678805
Change-Id: Ib2560b2cb0a92636a82c7e8efbf84d6d3b6c6b85
  • Loading branch information
lanctot committed Jul 30, 2024
1 parent 02021c3 commit 3cdecf7
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
11 changes: 8 additions & 3 deletions open_spiel/python/algorithms/ismcts_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,21 @@
"""Test the IS-MCTS Agent."""

from absl.testing import absltest
from absl.testing import parameterized
from open_spiel.python import rl_environment
from open_spiel.python.algorithms import ismcts
from open_spiel.python.algorithms import mcts
from open_spiel.python.algorithms import mcts_agent


class MCTSAgentTest(absltest.TestCase):
class MCTSAgentTest(parameterized.TestCase):

def test_tic_tac_toe_episode(self):
env = rl_environment.Environment("kuhn_poker", include_full_state=True)
@parameterized.named_parameters(
dict(testcase_name="tic_tac_toe", game_string="kuhn_poker"),
dict(testcase_name="leduc_poker", game_string="leduc_poker"),
)
def test_self_play_episode(self, game_string: str):
env = rl_environment.Environment(game_string, include_full_state=True)
num_players = env.num_players
num_actions = env.action_spec()["num_actions"]

Expand Down
10 changes: 10 additions & 0 deletions open_spiel/spiel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ StateType State::GetType() const {
}
}

std::unique_ptr<State> State::ResampleFromInfostate(
int player_id,
std::function<double()> rng) const {
if (GetGame()->GetType().information ==
GameType::Information::kPerfectInformation) {
return Clone();
}
SpielFatalError("ResampleFromInfostate() not implemented.");
}

bool GameType::ContainsRequiredParameters() const {
for (const auto& key_val : parameter_specification) {
if (key_val.second.is_mandatory()) {
Expand Down
7 changes: 4 additions & 3 deletions open_spiel/spiel.h
Original file line number Diff line number Diff line change
Expand Up @@ -654,10 +654,11 @@ class State {
// be interpreted as a cumulative distribution function, and will be used to
// sample from the legal chance actions. A good choice would be
// absl/std::uniform_real_distribution<double>(0., 1.).
//
// Default implementation checks if the game is a perfect information game.
// If so, it returns a clone, otherwise an error is thrown.
virtual std::unique_ptr<State> ResampleFromInfostate(
int player_id, std::function<double()> rng) const {
SpielFatalError("ResampleFromInfostate() not implemented.");
}
int player_id, std::function<double()> rng) const;

// Returns a vector of states & probabilities that are consistent with the
// infostate from the view of the current player. By default, this is not
Expand Down

0 comments on commit 3cdecf7

Please sign in to comment.