diff --git a/Coach.py b/Coach.py index bb26f5d16..a2d577ea0 100644 --- a/Coach.py +++ b/Coach.py @@ -35,7 +35,7 @@ def AsyncSelfPlay(game,args,iter_num,bar): net.load_checkpoint(folder=args.checkpoint, filename='best.pth.tar') except: pass - mcts = MCTS(game, net,args) + mcts = MCTS(game, net, args, dirichlet_noise=True) # create a list for store game state returnlist = [] diff --git a/MCTS.py b/MCTS.py index 737458485..a67b703ee 100644 --- a/MCTS.py +++ b/MCTS.py @@ -9,10 +9,11 @@ class MCTS(): This class handles the MCTS tree. """ - def __init__(self, game, nnet, args): + def __init__(self, game, nnet, args, dirichlet_noise=False): self.game = game self.nnet = nnet self.args = args + self.dirichlet_noise = dirichlet_noise self.Qsa = {} # stores Q values for s,a (as defined in the paper) self.Nsa = {} # stores #times edge s,a was visited self.Ns = {} # stores #times board s was visited @@ -31,7 +32,8 @@ def getActionProb(self, canonicalBoard, temp=1): proportional to Nsa[(s,a)]**(1./temp) """ for i in range(self.args.numMCTSSims): - self.search(canonicalBoard) + dir_noise = (i == 0 and self.dirichlet_noise) + self.search(canonicalBoard, dirichlet_noise=dir_noise) s = self.game.stringRepresentation(canonicalBoard) counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0 for a in range(self.game.getActionSize())] @@ -47,7 +49,7 @@ def getActionProb(self, canonicalBoard, temp=1): return probs - def search(self, canonicalBoard): + def search(self, canonicalBoard, dirichlet_noise=False): """ This function performs one iteration of MCTS. It is recursively called till a leaf node is found. The action chosen at each node is one that @@ -80,6 +82,8 @@ def search(self, canonicalBoard): self.Ps[s], v = self.nnet.predict(canonicalBoard) valids = self.game.getValidMoves(canonicalBoard, 1) self.Ps[s] = self.Ps[s]*valids # masking invalid moves + if self.dirichlet_noise: + self.applyDirNoise(s, valids) sum_Ps_s = np.sum(self.Ps[s]) if sum_Ps_s > 0: self.Ps[s] /= sum_Ps_s # renormalize @@ -98,6 +102,10 @@ def search(self, canonicalBoard): return -v valids = self.Vs[s] + if dirichlet_noise: + self.applyDirNoise(s, valids) + sum_Ps_s = np.sum(self.Ps[s]) + self.Ps[s] /= sum_Ps_s # renormalize cur_best = -float('inf') best_act = -1 @@ -129,3 +137,11 @@ def search(self, canonicalBoard): self.Ns[s] += 1 return -v + + def applyDirNoise(self, s, valids): + dir_values = np.random.dirichlet([self.args.dirichletAlpha] * np.count_nonzero(valids)) + dir_idx = 0 + for idx in range(len(self.Ps[s])): + if self.Ps[s][idx]: + self.Ps[s][idx] = (0.75 * self.Ps[s][idx]) + (0.25 * dir_values[dir_idx]) + dir_idx += 1 \ No newline at end of file diff --git a/main.py b/main.py index 34ae32883..448bbdc7c 100644 --- a/main.py +++ b/main.py @@ -28,6 +28,8 @@ 'numPerProcessAgainst': 15, 'checkpoint': 'temp/Uno/', 'numItersForTrainExamplesHistory': 20, + + 'dirichletAlpha': 0.6 # α = {0.3, 0.15, 0.03} for chess, shogi and Go respectively, scaled in inverse proportion to the approximate number of legal moves in a typical position }) if __name__=="__main__":