Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: "create tridiagonal_solve in Jax" #26235

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions ivy/functional/frontends/jax/lax/tridiagonal_solve
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import ivy
import ivy.numpy as ivy_np
from ivy.functional.frontends.jax.func_wrapper import to_ivy_arrays_and_back
from jax import lax

# ... (existing code)

@to_ivy_arrays_and_back
def svd(x, /, *, full_matrices=True, compute_uv=True):
if not compute_uv:
return ivy.svdvals(x)
return ivy.svd(x, full_matrices=full_matrices)

# Add your tridiagonal_solve function here:
@to_ivy_arrays_and_back
def tridiagonal_solve(L, D, U, b):
"""
Solves a tridiagonal linear system Ax = b given L, D, and U.

Parameters:
- L: Lower diagonal of the tridiagonal matrix A.
- D: Diagonal of the tridiagonal matrix A.
- U: Upper diagonal of the tridiagonal matrix A.
- b: Right-hand side vector.

Returns:
- x: Solution vector.
"""
n = len(D)

# Check input dimensions
if L.shape != (n - 1,) or U.shape != (n - 1,) or D.shape != (n,):
raise ValueError("Input dimensions do not match the tridiagonal matrix size.")

# Check if D contains any zero elements (avoid division by zero)
if ivy_np.any(D == 0):
raise ValueError("Diagonal elements of the tridiagonal matrix must not be zero.")

x = ivy_np.zeros_like(b)

# Forward substitution
for i in range(1, n):
if D[i] == 0:
raise ValueError("Diagonal element D[{}] is zero. Division by zero is not allowed.".format(i))
x = lax.index_update(x, i, (b[i] - L[i - 1] * x[i - 1]) / D[i])

# Backward substitution
for i in range(n - 2, -1, -1):
x = lax.index_update(x, i, x[i] - U[i] * x[i + 1])

return x
Loading