-
Notifications
You must be signed in to change notification settings - Fork 0
/
MCTS.py
49 lines (38 loc) · 1.57 KB
/
MCTS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import trange
import random
import math
from Node import Node
class MCTS:
    """Monte Carlo Tree Search guided by a policy/value model.

    Each simulation descends the tree with ``Node.select``, evaluates the
    reached leaf either with the game's terminal value or with the model,
    expands non-terminal leaves using the model's legal-move-masked policy,
    and backpropagates the value. ``search`` returns the root's child
    visit-count distribution.

    Attributes:
        game: Game rules object; this class calls its
            ``get_value_and_terminated``, ``get_opponent_value``,
            ``get_encoded_state`` and ``get_valid_moves`` methods and reads
            ``action_size``.
        args: Search settings mapping; ``args['num_searches']`` is the number
            of simulations per ``search`` call. It is also passed to every
            root ``Node``.
        model: Callable that takes a batched encoded state and returns
            ``(policy_logits, value)``.
    """

    def __init__(self, game, args, model):
        self.game = game
        self.args = args
        self.model = model

    @torch.no_grad()
    def search(self, state):
        """Run ``args['num_searches']`` simulations starting from ``state``.

        Args:
            state: Game state to search from; passed unchanged to the root
                ``Node``.

        Returns:
            A NumPy array of length ``game.action_size`` holding each root
            child's share of the total child visits, indexed by the child's
            ``action_taken``. If the root never received any child visits
            (e.g. ``num_searches`` is 0), the array is all zeros rather than
            NaN.
        """
        root = Node(self.game, self.args, state)

        for _ in range(self.args['num_searches']):
            node = root

            # Selection: keep descending while the current node has been expanded.
            while node.is_fully_expanded():
                node = node.select()

            # The terminal value is flipped before backpropagation.
            # NOTE(review): presumably get_value_and_terminated scores the
            # position for the player who played node.action_taken, and the
            # flip converts it to the perspective node.backpropagate expects —
            # confirm against the Game/Node implementations.
            value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
            value = self.game.get_opponent_value(value)

            if not is_terminal:
                # Evaluation + expansion with the model (batch of one state).
                # NOTE(review): the input tensor is created on the CPU; if the
                # model lives on another device it must be moved there — confirm.
                policy, value = self.model(
                    torch.tensor(self.game.get_encoded_state(node.state)).unsqueeze(0)
                )
                # Presumably policy logits are (1, action_size): softmax over the
                # action axis, then drop the batch dimension.
                policy = torch.softmax(policy, dim=1).squeeze(0).cpu().numpy()

                valid_moves = self.game.get_valid_moves(node.state)
                policy *= valid_moves
                policy_mass = np.sum(policy)
                if policy_mass > 0:
                    policy /= policy_mass
                else:
                    # A float32 softmax can underflow to exactly 0 on every legal
                    # move; dividing would turn every prior into NaN. Fall back to
                    # a uniform prior over the legal moves instead.
                    legal_count = np.sum(valid_moves)
                    if legal_count > 0:
                        policy = np.asarray(valid_moves, dtype=policy.dtype) / legal_count

                value = value.item()
                node.expand(policy)

            node.backpropagate(value)

        # Convert root child visit counts into a probability distribution.
        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count

        total_visits = np.sum(action_probs)
        if total_visits > 0:
            # Guard: with no expanded children the sum is 0 and dividing would
            # yield an all-NaN array.
            action_probs /= total_visits
        return action_probs