-
Notifications
You must be signed in to change notification settings - Fork 0
/
shooting_tests.py
99 lines (91 loc) · 3.71 KB
/
shooting_tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from numerical_shooting import orbit_calc
import numpy as np
from ode_functions import hopf_bf, hopf_exp, ode_3, ode_3_exp, predator_prey, ode_1
from integrate_ode import solve_ode
# Function to test Hopf bifurcation shooting result against explicit solution
def hopf_test(function, u0, phase, real_T=2 * np.pi):
    """Check a shooting-method periodic orbit of the Hopf normal form.

    Runs ``orbit_calc`` on ``function`` from the initial guess ``u0`` and
    then performs four checks, printing PASS/FAIL for each:

    1. the computed period matches ``real_T``;
    2. integrating one period (``solve_ode`` with 'rk4', step 0.01) from the
       computed start point returns to that start point;
    3. the computed start point matches ``hopf_exp(t, theta=np.pi, b=1)``
       at t = 0;
    4. the computed end point matches that explicit solution at t = T.

    Parameters
    ----------
    function : callable
        Right-hand side passed to both ``orbit_calc`` and ``solve_ode``.
    u0 : sequence
        Initial guess handed to ``orbit_calc``. Presumably the state
        components followed by a period guess, mirroring how the result is
        unpacked below -- confirm against ``orbit_calc``.
    phase :
        Forwarded to ``orbit_calc`` as ``var`` (presumably selects the phase
        condition -- confirm against ``orbit_calc``).
    real_T : float, optional
        Expected period of the orbit; defaults to 2*pi.

    Returns
    -------
    tuple of int
        ``(pass_count, fail_count)`` so callers can act on failures.
    """
    print('--------- Now running the Hopf bifurcation normal-form shooting tests (suitable inputs) ---------')
    pass_count = 0
    fail_count = 0

    def check(ok, pass_msg, fail_msg):
        # Record one check outcome and print the matching message.
        nonlocal pass_count, fail_count
        if ok:
            pass_count += 1
            print(pass_msg)
        else:
            fail_count += 1
            print(fail_msg)

    test_result = orbit_calc(function, u0, var=phase)
    # The result is unpacked as [state components..., period].
    T = test_result[-1]
    u_start = test_result[:-1]

    check(np.allclose(T, real_T),
          'Correct period found - PASS',
          'Incorrect period found - FAIL')

    t = np.linspace(0, T, 100)
    sol = solve_ode(function, u_start, t, 'rk4', 0.01)
    # sol is indexed as sol[component][time_index]; the normal form is
    # planar, so only the first two components are compared.
    point0 = [sol[0][0], sol[1][0]]
    pointf = [sol[0][-1], sol[1][-1]]
    check(np.allclose(point0, pointf),
          'Initial points match end points - PASS',
          'Initial and end points do not match - FAIL')

    # NOTE(review): theta and b are fixed here; presumably they must agree
    # with the phase condition and hopf_bf's parameters -- confirm.
    exp_sol = hopf_exp(t, theta=np.pi, b=1)
    real_point0 = [exp_sol[0][0], exp_sol[1][0]]
    real_pointf = [exp_sol[0][-1], exp_sol[1][-1]]
    check(np.allclose(point0, real_point0),
          'Calculated and explicit solution initial points match - PASS',
          'Calculated and explicit solution initial points do not match - FAIL')
    check(np.allclose(pointf, real_pointf),
          'Calculated and explicit solution end points match - PASS',
          'Calculated and explicit solution end points do not match - FAIL')

    print(f'Tests passed: {pass_count}')
    print(f'Tests failed: {fail_count}')
    return pass_count, fail_count
def du3_shoot_test(function, u0, phase, real_T=2 * np.pi):
    """Check a shooting-method periodic orbit of a 3-D system.

    Runs ``orbit_calc`` on ``function`` (previously the hard-coded
    ``ode_3`` was used and ``function`` was ignored), then performs four
    checks, printing PASS/FAIL for each:

    1. the computed period matches ``real_T``;
    2. integrating one period (``solve_ode`` with 'rk4', step 0.01) from the
       computed start point returns to it in EVERY state component;
    3. the first two components of the start point match
       ``hopf_exp(t, theta=np.pi, b=1)`` at t = 0;
    4. the first two components of the end point match it at t = T.

    Parameters
    ----------
    function : callable
        Right-hand side passed to both ``orbit_calc`` and ``solve_ode``.
    u0 : sequence
        Initial guess handed to ``orbit_calc``. Presumably the state
        components followed by a period guess, mirroring how the result is
        unpacked below -- confirm against ``orbit_calc``.
    phase :
        Forwarded to ``orbit_calc`` as ``var`` (presumably selects the phase
        condition -- confirm against ``orbit_calc``).
    real_T : float, optional
        Expected period of the orbit; defaults to 2*pi.

    Returns
    -------
    tuple of int
        ``(pass_count, fail_count)`` so callers can act on failures.
    """
    print('--------- Now running the 3-D system shooting tests (suitable inputs) ---------')
    pass_count = 0
    fail_count = 0

    def check(ok, pass_msg, fail_msg):
        # Record one check outcome and print the matching message.
        nonlocal pass_count, fail_count
        if ok:
            pass_count += 1
            print(pass_msg)
        else:
            fail_count += 1
            print(fail_msg)

    test_result = orbit_calc(function, u0, var=phase)
    print(test_result)
    # The result is unpacked as [state components..., period].
    T = test_result[-1]
    u_start = test_result[:-1]

    check(np.allclose(T, real_T),
          'Correct period found - PASS',
          'Incorrect period found - FAIL')

    t = np.linspace(0, T, 100)
    sol = solve_ode(function, u_start, t, 'rk4', 0.01)
    # sol is indexed as sol[component][time_index] (assumed one row per
    # state component -- confirm against solve_ode). A periodic orbit must
    # close in every component, not just the first two.
    start_state = [row[0] for row in sol]
    end_state = [row[-1] for row in sol]
    check(np.allclose(start_state, end_state),
          'Initial points match end points - PASS',
          'Initial and end points do not match - FAIL')

    # Only the first two components are compared with hopf_exp, presumably
    # because the planar part of this system is the Hopf normal form.
    # NOTE(review): ode_3_exp is imported but unused; the third component is
    # never checked against an explicit solution -- consider using it once
    # its signature is confirmed.
    point0 = [sol[0][0], sol[1][0]]
    pointf = [sol[0][-1], sol[1][-1]]
    exp_sol = hopf_exp(t, theta=np.pi, b=1)
    real_point0 = [exp_sol[0][0], exp_sol[1][0]]
    real_pointf = [exp_sol[0][-1], exp_sol[1][-1]]
    check(np.allclose(point0, real_point0),
          'Calculated and explicit solution initial points match - PASS',
          'Calculated and explicit solution initial points do not match - FAIL')
    check(np.allclose(pointf, real_pointf),
          'Calculated and explicit solution end points match - PASS',
          'Calculated and explicit solution end points do not match - FAIL')

    print(f'Tests passed: {pass_count}')
    print(f'Tests failed: {fail_count}')
    return pass_count, fail_count
if __name__ == '__main__':
    # Run both shooting test suites when executed as a script.
    # Each guess is presumably [state components..., period guess], matching
    # how the tests unpack orbit_calc's result -- confirm against orbit_calc.
    phase = 0
    hopf_guess = [1, 1, 6]
    ode3_guess = [-1, -1, 0, 6]
    hopf_test(hopf_bf, hopf_guess, phase)
    du3_shoot_test(ode_3, ode3_guess, phase)