Skip to content

Commit

Permalink
anh fix for stream behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
phinate committed Jul 19, 2023
1 parent feb0639 commit ac107ae
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/signax/signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,22 @@ def _body(carry, path_inc):
return ret, ret

carry, stacked = jax.lax.scan(f=_body, init=exp_term, xs=path_increments[1:])

if stream:
res = [
jnp.concatenate([first[None, ...], rest], axis=0)
for first, rest in zip(exp_term, stacked)
]
# here `res` has shape [(patch_len - 1, dim), (patch_len - 1, dim, dim), ...]
if flatten:
res = jax.vmap(lambda x: flatten_util.ravel_pytree(x)[0])(res)
# now `res` has shape (patch_len -1, dim + dim * dim + ...)
else:
res = carry
if flatten:
res = flatten_util.ravel_pytree(res)[0]
# `res` has shape [(dim,), (dim, dim), ...]
if flatten:
res = flatten_util.ravel_pytree(res)[
0
] # `res` has shape (dim + dim * dim + ..., )
return res


Expand Down

0 comments on commit ac107ae

Please sign in to comment.