I have tried to implement this myself using Givens rotations, but because it needs N*M sequential loop iterations it is very slow. That makes it useless, since the whole point of this function is to compute the QR factorization faster than simply calling `jax.scipy.linalg.qr`. I think this would be useful for anyone using `jax` for trust-region Newton optimization. I considered Householder reflections, but I couldn't get an implementation to work under `jax.jit`. Could you add this utility to `jax`, or give some feedback on how to make my implementation faster?
```python
import jax
import jax.numpy as jnp
from jax.lax import fori_loop, rsqrt

def _givens_jax(a, b):
    # 2x2 rotation G2 such that (G2 @ [a, b])[1] == 0; identity when b == 0
    b_zero = abs(b) == 0
    a_lt_b = abs(a) < abs(b)
    t = -jnp.where(a_lt_b, a, b) / jnp.where(a_lt_b, b, a)
    r = rsqrt(1 + abs(t) ** 2).astype(t.dtype)
    cs = jnp.where(b_zero, 1, jnp.where(a_lt_b, r * t, r))
    sn = jnp.where(b_zero, 0, jnp.where(a_lt_b, r, r * t))
    G2 = jnp.array([[cs, -sn], [sn, cs]])
    return G2.astype(float)

@jax.jit
def update_qr_jax(A, w, q, r):
    """Update QR factorization with a diagonal matrix w at the bottom."""
    m, n = A.shape
    Q = jnp.eye(m + n)
    Q = Q.at[:m, :m].set(q)
    R = jnp.vstack([r, w])

    def body_inner(i, jQR):
        j, Q, R = jQR
        # Walk up from row m + j to row j + 1, zeroing R[i, j] against R[i - 1, j]
        i = m + j - i
        a, b = R[i - 1, j], R[i, j]
        G2 = _givens_jax(a, b)
        rows = jnp.array([i - 1, i])
        R = R.at[rows].set(G2 @ R[rows])
        Q = Q.at[:, rows].set(Q[:, rows] @ G2.T)  # keeps Q @ R unchanged
        return j, Q, R

    def body(j, QR):
        Q, R = QR
        # m rotations per column, i.e. m * n sequential rotations overall
        j, Q, R = fori_loop(0, m, body_inner, (j, Q, R))
        return Q, R

    Q, R = fori_loop(0, n, body, (Q, R))
    R = jnp.where(jnp.abs(R) < 1e-10, 0, R)
    return Q, R
```
Note: I also tried economic-mode QR to reduce the matrix size, but it is still slow.
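A likely reason for the slowdown is the rotation count. The `update_qr_jax` loop applies `m` rotations per column (`m * n` in total), moving each entry of `w` up through every row of `r` one adjacent pair at a time. Because `r` is upper triangular and `w` is diagonal, row `k` of `w` only has to be rotated against rows `k, ..., n - 1` of `r`. Givens rotations can pair any two rows, not just adjacent ones, so this needs `n * (n + 1) / 2` rotations and does not grow with `m`. Below is a minimal sketch of that idea. It assumes `m >= n`, a complete `m x m` `q` and an `m x n` `r` (as in `update_qr_jax`), and a diagonal `w`. The names `update_qr_givens` and `_givens` are hypothetical. The rotations still run sequentially, so whether this beats `jax.scipy.linalg.qr` at your problem sizes would need to be measured.

```python
import jax
import jax.numpy as jnp
from jax import lax


def _givens(a, b):
    # (c, s) with [[c, s], [-s, c]] @ [a, b] == [hypot(a, b), 0]
    h = jnp.hypot(a, b)
    safe_h = jnp.where(h == 0, 1.0, h)  # avoid 0/0 when a == b == 0
    c = jnp.where(h == 0, 1.0, a / safe_h)
    s = jnp.where(h == 0, 0.0, b / safe_h)
    return c, s


@jax.jit
def update_qr_givens(A, w, q, r):
    """QR of [A; w] given A = q @ r (q: m x m, r: m x n, m >= n) and diagonal w (n x n)."""
    m, n = A.shape
    Q = jnp.eye(m + n, dtype=q.dtype).at[:m, :m].set(q)
    R = jnp.vstack([r, w])

    def outer(k, QR):
        # Eliminate stacked row m + k (row k of w) against pivot rows k, ..., n - 1
        def inner(j, QR):
            Q, R = QR
            p, t = j, m + k  # pivot row p, target row t
            c, s = _givens(R[p, j], R[t, j])
            Rp, Rt = R[p], R[t]
            R = R.at[p].set(c * Rp + s * Rt).at[t].set(-s * Rp + c * Rt)
            # Apply the transposed rotation to the columns so that Q @ R is unchanged
            Qp, Qt = Q[:, p], Q[:, t]
            Q = Q.at[:, p].set(c * Qp + s * Qt).at[:, t].set(-s * Qp + c * Qt)
            return Q, R

        return lax.fori_loop(k, n, inner, QR)

    return lax.fori_loop(0, n, outer, (Q, R))
```

Each rotation touches two rows of `R` and two columns of `Q`, chosen directly. If only `R` is needed, dropping the `Q` updates removes the remaining `O(m + n)` work per rotation.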
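For context: in my application I need the QR factorization of

$$\begin{bmatrix} A \\ W \end{bmatrix},$$

where $A$ is $m \times n$, $W$ is an $n \times n$ diagonal matrix, and I already have $A = QR$.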
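When `A = q @ r` is already known, the new factorization can also be handed to the built-in QR on a much smaller matrix. Since `r` is zero below its first `n` rows, $\begin{bmatrix} A \\ W \end{bmatrix} = \begin{bmatrix} q_n & 0 \\ 0 & I \end{bmatrix} \begin{bmatrix} r_n \\ W \end{bmatrix}$, where $q_n$ is the first $n$ columns of `q` and $r_n$ is the top $n \times n$ block of `r`. Only the $2n \times n$ matrix $\begin{bmatrix} r_n \\ W \end{bmatrix}$ then needs a fresh QR, and its Q factor is lifted back with one matrix product. A minimal sketch, assuming $m \ge n$. `update_qr_reduced` is a hypothetical name, and it returns the economic (reduced) factors rather than the full square `Q`:

```python
import jax
import jax.numpy as jnp


@jax.jit
def update_qr_reduced(q, r, w):
    """Economic QR of [A; w] given A = q @ r, with A of shape (m, n) and m >= n."""
    n = r.shape[1]
    q_n = q[:, :n]  # first n columns of q (works for a complete or a reduced q)
    r_n = r[:n]     # top n x n upper-triangular block of r
    # [A; w] = blockdiag(q_n, I) @ [r_n; w], so only this (2n) x n matrix is refactored
    q2, r2 = jnp.linalg.qr(jnp.vstack([r_n, w]))  # reduced mode: q2 is (2n) x n
    # Lift the small Q factor back: blockdiag(q_n, I) @ q2
    Q = jnp.vstack([q_n @ q2[:n], q2[n:]])  # (m + n) x n with orthonormal columns
    return Q, r2
```

If only `r2` is needed (for example, because $R^\top R = A^\top A + W^\top W$ is what a trust-region step solves with), the final `vstack` can be skipped, and the remaining work depends only on `n`.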