-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move signature implementation to a single function, and catch shape errors #35
Conversation
src/signax/signatures.py
Outdated
return carry | ||
else: | ||
res = carry | ||
if flatten: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the shape of res
when stream=True
(shape = [(path_len-1, dim), (path_len-1, dim, dim]
) is different to when stream=False
(shape = [(dim,), (dim, dim)]
), we may use jax.vmap(utils.flatten)
when stream=True
.
We vmap the first dimesion path_len - 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
covered this by using ravel_pytree
and getting rid of flatten
:
import jax.numpy as jnp
from jax import flatten_util
arr = [jnp.ones((2, 3)), jnp.ones((2, 3, 4))]
tree = flatten_util.ravel_pytree(arr)[0]
print(tree)
# > [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
src/signax/signatures.py
Outdated
return [ | ||
sig_fun(path[i]) for i in range(path.shape[0]) | ||
] # otherwise, use scan to handle the batch dimension in the list |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is not necessary after we handle flatten
correctly in the _signature
function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so i did this because i couldn't figure out how to get the batch dimension as the dimension of the list! It seems like vmap only correctly broadcasts the batch dimension to the arrays, but maybe I fed things in wrong. This is the only thing I did that gave the correct shape for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this problem does not exist when flatten=True
, since it's arrays all the way down!
def logsignature( | ||
path: Float[Array, "path_len dim"], depth: int, stream: bool = False | ||
path: Float[Array, "path_len dim"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess you want to make logsignature
to handle the input (batch_size, path_len, dim)
as well.
In this case, we also need to handle the shape and need an extra vmap
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should work automatically by dispatching to those checks in signature
, which should do the right things I think! But let me know if there's logic I'm missing here.
tests/test_against_signatory.py
Outdated
path = rng.standard_normal(length) | ||
signature(path, depth) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can remove this. This actually causes error since path
is 1D.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done!
ok thinks look mostly okay now, i think i've caught the majority of errors -- there are still a couple tests failing with |
From the comment signax/src/signax/signatures.py Line 72 in feb0639
I'm curious about why we cannot vmap a function returning a list.
I test with a simple example like this: import jax
import jax.numpy as jnp
from jax import flatten_util
def fn(x, flatten=False):
n = x.shape[0]
res = [jnp.ones((n, )), jnp.ones((n, n))]
if flatten:
res = flatten_util.ravel_pytree(res)[0]
return res
# normal case
x = jnp.ones((2,))
output = fn(x)
[print(o.shape) for o in output]
# returned shape: [(2,), (2, 2)]
# vmap
print("--- vmap ---")
batch_size = 5
x = jnp.ones((batch_size, 2))
output = jax.vmap(fn)(x)
[print(o.shape) for o in output]
# return [(batch_size, 2), (batch_size, 2, 2)]
# flatten + vmap
print("--- flatten + vmap ---")
batch_size = 5
x = jnp.ones((batch_size, 2))
output = jax.vmap(lambda _x: fn(_x, flatten=True))(x)
print(output.shape)
# return (batch_size, 2 + 2 * 2) We can make I suggest to handle signax/src/signax/signatures.py Lines 115 to 123 in feb0639
with something like this if stream:
res = [
jnp.concatenate([first[None, ...], rest], axis=0)
for first, rest in zip(exp_term, stacked)
]
# here `res` has shape [(patch_len - 1, dim), (patch_len - 1, dim, dim), ...]
if flatten:
res =jax.vmap(lambda _x: flatten_util.ravel_pytree(x)[0])(res)
# now `res` has shape (patch_len -1, dim + dim * dim + ...)
else:
res = carry
# `res` has shape [(dim,), (dim, dim), ...]
if flatten:
res = flatten_util.ravel_pytree(res)[0] # `res` has shape (dim + dim * dim + ..., ) I believe the following line works well. For consistency, probably we can change from signax/src/signax/signatures.py Line 214 in feb0639
|
It's the second case that I implemented this line for -- would you not expect
cool, I'll look at this now -- thanks for your help @anh-tong :D |
Codecov Report
@@ Coverage Diff @@
## main #35 +/- ##
==========================================
+ Coverage 92.85% 93.93% +1.08%
==========================================
Files 7 7
Lines 224 264 +40
==========================================
+ Hits 208 248 +40
Misses 16 16
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
stream fix worked! but the code's a bit messy now -- would need to think about how to handle things for logsignature (i currently force no flattening since |
Hi @phinate, sorry for taking so long. I have updated the code in the following:
Indeed, keeping track of the shape can be challenging, especially when considering the stream and flatten options being turned on or off. Let me know if the behavior of current version is what you expect. |
Let's merge this. Thanks for PR. |
I'm playing around with dispatching
signax.signature
in such a way that the correct behaviour happens depending on the shape of the path.As an example, the function will still work on paths of shape (a, b ,c) and just (d) too, but both of these may not be the desired behaviour. So, this makes sure that if you feed in a path that is not right, we catch it as an error.
Also, if the path has a batch dim, this will broadcast the signature over the batch with vmap!
flatten
andnum_chunks
have become arguments tosignature
.