-
Notifications
You must be signed in to change notification settings - Fork 3
/
model_loader.py
67 lines (59 loc) · 2.77 KB
/
model_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from collections import defaultdict
def get_acq_function(framework_id=None,
                     aggregation=None,
                     problem=None,
                     n_dir=-1):
    """Find all acquisition functions needed to run the requested frameworks.

    Acquisition functions are identified by string names, generated per
    framework group:

    * ``"f<k>"`` for k = 1..problem.n_obj        -- frameworks 11, 12, 21, 22
    * ``"g<k>"`` for k = 1..problem.n_constr     -- frameworks 11, 12, 31, 32
    * ``"G_<a>"`` for a in aggregation['G']      -- frameworks 21, 22, 41, 42
    * ``"l<d>_<a>"`` for d = 1..n_dir, a in aggregation['l']
                                                 -- frameworks 31, 32, 41, 42
    * ``"fg_M5_<d>_<a>"`` for d = 1..n_dir, a in aggregation['fg_M5']
                                                 -- framework 5
    * ``"fg_M6_0_<a>"`` for a in aggregation['fg_M6'] other than 'asfcv'
                                                 -- framework 6B
    * ``"fg_M6_<d>_asfcv"`` for d = 1..n_dir, if 'asfcv' is in
      aggregation['fg_M6']                       -- framework 6A

    Only names used by at least one requested framework are generated, so
    ``problem`` and each ``aggregation`` key are read only when needed.

    Args:
        framework_id: Iterable of framework id strings (e.g. ``['11', '6A']``).
            A single id given as a bare string is treated as one id.
            ``None`` means no frameworks are requested.
        aggregation: Mapping whose ``'G'``, ``'l'``, ``'fg_M5'`` and
            ``'fg_M6'`` entries are iterables of aggregation identifiers.
        problem: Object exposing integer ``n_obj`` and ``n_constr`` counts.
        n_dir: Number of directions. The default ``-1`` produces no
            direction-dependent names (``range(-1)`` is empty).

    Returns:
        ``(acq_list, framework_acq_dict)`` where ``acq_list`` holds the
        needed names in generation order and ``framework_acq_dict`` is a
        ``defaultdict(list)`` mapping each requested framework id that needs
        at least one function to its names, in the same order.
    """
    if framework_id is None:
        requested = set()
    elif isinstance(framework_id, str):
        # A bare string is one id, not an iterable of characters:
        # set('6A') would be {'6', 'A'} and match no framework.
        requested = {framework_id}
    else:
        requested = set(framework_id)

    acq_list = []
    framework_acq_dict = defaultdict(list)

    def _users(*frameworks):
        # Subset of a group's frameworks that were actually requested.
        return [fw for fw in frameworks if fw in requested]

    def _register(name, users):
        # Record one acquisition function for every requesting framework.
        acq_list.append(name)
        for fw in users:
            framework_acq_dict[fw].append(name)

    users = _users('11', '12', '21', '22')
    if users:
        for i in range(problem.n_obj):
            _register("f" + str(i + 1), users)

    users = _users('11', '12', '31', '32')
    if users:
        for j in range(problem.n_constr):
            _register("g" + str(j + 1), users)

    users = _users('21', '22', '41', '42')
    if users:
        for i in aggregation['G']:
            _register("G_" + str(i), users)

    users = _users('31', '32', '41', '42')
    if users:
        for i in range(n_dir):
            for j in aggregation['l']:
                _register("l" + str(i + 1) + '_' + str(j), users)

    if '5' in requested:
        for i in range(n_dir):
            for j in aggregation['fg_M5']:
                _register("fg_M5_" + str(i + 1) + '_' + str(j), ['5'])

    want_6a = '6A' in requested
    want_6b = '6B' in requested
    if want_6a or want_6b:
        # 6B uses one direction-independent name per non-'asfcv' aggregation;
        # 6A uses one name per direction for 'asfcv'. Each is generated only
        # when its own framework is requested.
        for j in aggregation['fg_M6']:
            if j != 'asfcv':
                if want_6b:
                    _register("fg_M6_0_" + str(j), ['6B'])
            elif want_6a:
                for i in range(n_dir):
                    _register("fg_M6_" + str(i + 1) + '_' + str(j), ['6A'])

    return acq_list, framework_acq_dict