Skip to content

Commit

Permalink
Added dirichlet noise by jrbuhl93 (suragnair#186)
Browse files Browse the repository at this point in the history
  • Loading branch information
aps0127 authored and aps0127 committed Jun 24, 2020
1 parent 5ce393c commit b173f6b
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Coach.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def AsyncSelfPlay(game,args,iter_num,bar):
net.load_checkpoint(folder=args.checkpoint, filename='best.pth.tar')
except:
pass
mcts = MCTS(game, net,args)
mcts = MCTS(game, net, args, dirichlet_noise=True)

# create a list for store game state
returnlist = []
Expand Down
22 changes: 19 additions & 3 deletions MCTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ class MCTS():
This class handles the MCTS tree.
"""

def __init__(self, game, nnet, args):
def __init__(self, game, nnet, args, dirichlet_noise=False):
self.game = game
self.nnet = nnet
self.args = args
self.dirichlet_noise = dirichlet_noise
self.Qsa = {} # stores Q values for s,a (as defined in the paper)
self.Nsa = {} # stores #times edge s,a was visited
self.Ns = {} # stores #times board s was visited
Expand All @@ -31,7 +32,8 @@ def getActionProb(self, canonicalBoard, temp=1):
proportional to Nsa[(s,a)]**(1./temp)
"""
for i in range(self.args.numMCTSSims):
self.search(canonicalBoard)
dir_noise = (i == 0 and self.dirichlet_noise)
self.search(canonicalBoard, dirichlet_noise=dir_noise)

s = self.game.stringRepresentation(canonicalBoard)
counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0 for a in range(self.game.getActionSize())]
Expand All @@ -47,7 +49,7 @@ def getActionProb(self, canonicalBoard, temp=1):
return probs


def search(self, canonicalBoard):
def search(self, canonicalBoard, dirichlet_noise=False):
"""
This function performs one iteration of MCTS. It is recursively called
till a leaf node is found. The action chosen at each node is one that
Expand Down Expand Up @@ -80,6 +82,8 @@ def search(self, canonicalBoard):
self.Ps[s], v = self.nnet.predict(canonicalBoard)
valids = self.game.getValidMoves(canonicalBoard, 1)
self.Ps[s] = self.Ps[s]*valids # masking invalid moves
if self.dirichlet_noise:
self.applyDirNoise(s, valids)
sum_Ps_s = np.sum(self.Ps[s])
if sum_Ps_s > 0:
self.Ps[s] /= sum_Ps_s # renormalize
Expand All @@ -98,6 +102,10 @@ def search(self, canonicalBoard):
return -v

valids = self.Vs[s]
if dirichlet_noise:
self.applyDirNoise(s, valids)
sum_Ps_s = np.sum(self.Ps[s])
self.Ps[s] /= sum_Ps_s # renormalize
cur_best = -float('inf')
best_act = -1

Expand Down Expand Up @@ -129,3 +137,11 @@ def search(self, canonicalBoard):

self.Ns[s] += 1
return -v

def applyDirNoise(self, s, valids):
    """Blend Dirichlet exploration noise into the stored prior ``self.Ps[s]``.

    One Dirichlet sample (concentration ``self.args.dirichletAlpha``) is
    drawn with one component per valid move, and each valid move's prior is
    replaced by ``0.75 * prior + 0.25 * noise``. Invalid moves are left
    untouched. ``self.Ps[s]`` is modified in place; nothing is returned.

    Args:
        s: Key of the state whose prior is stored in ``self.Ps``.
        valids: Mask over actions; non-zero entries mark valid moves.
            Assumed to have the same length as ``self.Ps[s]`` and
            ``self.Ps[s]`` is assumed to be a float NumPy array -- confirm
            against the caller.
    """
    # Select by the validity mask rather than by "prior is non-zero": a legal
    # move whose network prior is exactly 0 must still receive noise,
    # otherwise the noise components that are actually applied no longer
    # sum to 1 and that move can never be explored.
    valid_idx = np.flatnonzero(valids)
    if valid_idx.size == 0:
        # No legal moves: nothing to perturb, and there is no meaningful
        # Dirichlet distribution over zero components.
        return

    noise = np.random.dirichlet([self.args.dirichletAlpha] * valid_idx.size)
    priors = self.Ps[s]
    priors[valid_idx] = 0.75 * priors[valid_idx] + 0.25 * noise
2 changes: 2 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
'numPerProcessAgainst': 15,
'checkpoint': 'temp/Uno/',
'numItersForTrainExamplesHistory': 20,

'dirichletAlpha': 0.6 # α = {0.3, 0.15, 0.03} for chess, shogi and Go respectively, scaled in inverse proportion to the approximate number of legal moves in a typical position
})

if __name__=="__main__":
Expand Down

0 comments on commit b173f6b

Please sign in to comment.