Skip to content

Commit

Permalink
Fixing and tuning ES solver.
Browse files Browse the repository at this point in the history
  • Loading branch information
Akon32 committed Jul 9, 2023
1 parent 82da6fb commit bd7579f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
4 changes: 2 additions & 2 deletions py/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def create_inhibition_matrix(mus_places: np.ndarray,

if pillar_center_radius.shape[0] > 0:
block_coords = np.append(mus_places, pillar_center_radius[:, 0:2], axis=0)
block_radiuses = np.append(np.ones((n_mus, 1)) * mus_radius, pillar_center_radius[:, 2], axis=0)
block_radiuses = np.append(np.ones(n_mus) * mus_radius, pillar_center_radius[:, 2], axis=0)
else:
block_coords = mus_places
block_radiuses = np.ones((n_mus, 1)) * mus_radius
Expand Down Expand Up @@ -130,7 +130,7 @@ def random_inds(arr, ratio):
mus_places_volumes[mus_inds],
att_places[att_inds],
att_tastes[att_inds],
pillar_center_radius[pillar_inds] if pillar_inds else np.array([]),
pillar_center_radius[pillar_inds] if len(pillar_inds) > 0 else np.array([]),
use_playing_together_ext)
score_sum += sc

Expand Down
16 changes: 9 additions & 7 deletions py/solving.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from dataclasses import dataclass

import numpy as np
Expand All @@ -21,7 +22,7 @@ def ES_solve(stage: Stage,
att_tastes: np.ndarray,
pillar_center_radius: np.ndarray,
use_playing_together_ext: bool = True) -> (np.ndarray, float, bool):
initial_sigmas = np.array([stage.w / 3, stage.h / 3, 5 / 3], dtype='float64')
initial_sigmas = np.array([stage.w / 3 / 2, stage.h / 3 / 2, 5 / 3 / 2], dtype='float64')
step_count = 100
population_size = 20
score_k = 1
Expand All @@ -33,10 +34,10 @@ def ES_solve(stage: Stage,
def call_score(mus_places_volumes: np.ndarray) -> (float, bool, float, float):
sc = mc_score(mus_instruments, mus_places_volumes, att_places, att_tastes, pillar_center_radius,
use_playing_together_ext,
mus_ratio=0.05,
att_ratio=0.05,
pillar_ratio=0.05,
n_eval=50)
mus_ratio=0.01,
att_ratio=0.01,
pillar_ratio=0.01,
n_eval=30)
mus_dist_penalty = musicians_distance_penalty(mus_places_volumes)
mus_scene_penalty = musicians_out_of_scene_penalty(stage, mus_places_volumes)
valid = mus_dist_penalty <= 0 and mus_scene_penalty <= 0
Expand All @@ -55,14 +56,15 @@ def gen_one():
current_solution = initial_solution(stage, n_musicians, [0.001, 0.001, 0.1])
current_score = call_score(current_solution)
for step in range(step_count):
sigmas = initial_sigmas - initial_sigmas * step / step_count
# sigmas = initial_sigmas * (1.0 - step / step_count)
sigmas = initial_sigmas * (math.exp(-step / step_count * 6))
print(f"Step {step}/{step_count} Approx.score={current_score} sigmas={sigmas}")
population = gen_population(current_solution, sigmas)
# TODO: make solutions valid / filter invalid

for solution in population:
sol_score = call_score(solution)
if sol_score[0] > current_score[0] and sol_score[1] >= current_score[1]:
if sol_score[0] > current_score[0] and (step < step_count * 0.7 or sol_score[1] >= current_score[1]):
current_solution = solution
current_score = sol_score
print(f"Finished {step_count} steps. Approx.score={current_score}")
Expand Down

0 comments on commit bd7579f

Please sign in to comment.