Skip to content

Commit

Permalink
fix test svd (related to signs of singular values)
Browse files Browse the repository at this point in the history
  • Loading branch information
jvdp1 committed Jul 14, 2024
1 parent a690390 commit 40765ad
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions test/linalg/test_linalg_svd.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ module test_linalg_svd
if (allocated(error)) return
call check(error, all(abs(s-s_sol)<=tol), test//': S')
if (allocated(error)) return
call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
call check(error, all(abs(abs(u)-abs(u_sol))<=tol), test//': U')
if (allocated(error)) return

!> [S, U]. Overwrite A matrix
Expand All @@ -104,7 +104,7 @@ module test_linalg_svd
if (allocated(error)) return
call check(error, all(abs(s-s_sol)<=tol), test//': S')
if (allocated(error)) return
call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
call check(error, all(abs(abs(u)-abs(u_sol))<=tol), test//': U')
if (allocated(error)) return

!> [S, U, V^T]
Expand All @@ -116,9 +116,9 @@ module test_linalg_svd
if (allocated(error)) return
call check(error, all(abs(s-s_sol)<=tol), test//': S')
if (allocated(error)) return
call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
call check(error, all(abs(abs(u)-abs(u_sol))<=tol), test//': U')
if (allocated(error)) return
call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
call check(error, all(abs(abs(vt)-abs(vt_sol))<=tol), test//': V^T')
if (allocated(error)) return

!> [S, V^T]. Do not overwrite A matrix
Expand All @@ -130,7 +130,7 @@ module test_linalg_svd
if (allocated(error)) return
call check(error, all(abs(s-s_sol)<=tol), test//': S')
if (allocated(error)) return
call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
call check(error, all(abs(abs(vt)-abs(vt_sol))<=tol), test//': V^T')
if (allocated(error)) return

!> [S, V^T]. Overwrite A matrix
Expand All @@ -141,7 +141,7 @@ module test_linalg_svd
if (allocated(error)) return
call check(error, all(abs(s-s_sol)<=tol), test//': S')
if (allocated(error)) return
call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
call check(error, all(abs(abs(vt)-abs(vt_sol))<=tol), test//': V^T')
if (allocated(error)) return

!> [U, S, V^T].
Expand All @@ -151,11 +151,11 @@ module test_linalg_svd
test = '[U, S, V^T]'
call check(error,state%ok(),test//': '//state%print())
if (allocated(error)) return
call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
call check(error, all(abs(abs(u)-abs(u_sol))<=tol), test//': U')
if (allocated(error)) return
call check(error, all(abs(s-s_sol)<=tol), test//': S')
if (allocated(error)) return
call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
call check(error, all(abs(abs(vt)-abs(vt_sol))<=tol), test//': V^T')
if (allocated(error)) return

!> [U, S, V^T]. Partial storage -> compare until k=2 columns of U rows of V^T
Expand All @@ -167,11 +167,11 @@ module test_linalg_svd
test = '[U, S, V^T], partial storage'
call check(error,state%ok(),test//': '//state%print())
if (allocated(error)) return
call check(error, all(abs(u(:,:2)-u_sol(:,:2))<=tol) .or. all(abs(u(:,:2)+u_sol(:,:2))<=tol), test//': U(:,:2)')
call check(error, all(abs(abs(u(:,:2))-abs(u_sol(:,:2)))<=tol), test//': U(:,:2)')
if (allocated(error)) return
call check(error, all(abs(s-s_sol)<=tol), test//': S')
if (allocated(error)) return
call check(error, all(abs(vt(:2,:)-vt_sol(:2,:))<=tol) .or. all(abs(vt(:2,:)+vt_sol(:2,:))<=tol), test//': V^T(:2,:)')
call check(error, all(abs(abs(vt(:2,:))-abs(vt_sol(:2,:)))<=tol), test//': V^T(:2,:)')
if (allocated(error)) return

end subroutine test_svd_${ri}$
Expand Down

0 comments on commit 40765ad

Please sign in to comment.