
Move signature implementation to a single function, and catch shape errors #35

Merged
anh-tong merged 20 commits into main from shape-handling on Jul 28, 2023

Conversation

@phinate (Collaborator) commented Jul 12, 2023

I'm playing around with dispatching signax.signature in such a way that the correct behaviour happens depending on the shape of the path.

As an example, the function will currently still work on paths of shape (a, b, c) and even just (d,), but neither of these may be the desired behaviour. So this makes sure that if you feed in a path with the wrong shape, we catch it as an error.

Also, if the path has a batch dim, this will broadcast the signature over the batch with vmap! flatten and num_chunks have become arguments to signature.
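A rough sketch of what this dispatch could look like (illustrative only: `_signature` here is a stand-in for the real single-path computation, and the names and signatures are simplified, not signax's actual API):

```python
import jax
import jax.numpy as jnp

def _signature(path, depth):
    # stand-in for the real single-path signature computation;
    # returns a flat vector of length dim + dim**2 + ... + dim**depth
    dim = path.shape[-1]
    return jnp.zeros(sum(dim**k for k in range(1, depth + 1)))

def signature(path, depth):
    path = jnp.asarray(path)
    if path.ndim == 2:   # (path_len, dim): a single path
        return _signature(path, depth)
    if path.ndim == 3:   # (batch, path_len, dim): broadcast over the batch with vmap
        return jax.vmap(lambda p: _signature(p, depth))(path)
    raise ValueError(
        f"expected shape (path_len, dim) or (batch, path_len, dim), got {path.shape}"
    )
```

With this shape check, a 1D input now raises instead of silently producing something unintended.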

        return carry
    else:
        res = carry
    if flatten:
anh-tong (Owner):

Since the shape of res when stream=True (shape = [(path_len-1, dim), (path_len-1, dim, dim)]) differs from when stream=False (shape = [(dim,), (dim, dim)]), we may use jax.vmap(utils.flatten) when stream=True.

We vmap over the first dimension, of size path_len - 1.
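For illustration, the vmapped flatten on dummy data with the stream=True shapes described above (using jax.flatten_util.ravel_pytree in place of utils.flatten):

```python
import jax
import jax.numpy as jnp
from jax import flatten_util

path_len, dim = 5, 3
# stream=True result: one array per signature level, leading axis of size path_len - 1
res = [jnp.ones((path_len - 1, dim)), jnp.ones((path_len - 1, dim, dim))]

# vmap over the leading (stream) axis, flattening the levels at each step
flat = jax.vmap(lambda levels: flatten_util.ravel_pytree(levels)[0])(res)
print(flat.shape)  # (4, 12), i.e. (path_len - 1, dim + dim * dim)
```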

phinate (Collaborator, Author):

covered this by using ravel_pytree and getting rid of flatten:

import jax.numpy as jnp
from jax import flatten_util

arr = [jnp.ones((2, 3)), jnp.ones((2, 3, 4))]
tree = flatten_util.ravel_pytree(arr)[0]
print(tree)
# > [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]

Comment on lines 70 to 72
return [
    sig_fun(path[i]) for i in range(path.shape[0])
]  # otherwise, use scan to handle the batch dimension in the list
anh-tong (Owner):

I think this is not necessary after we handle flatten correctly in the _signature function.

phinate (Collaborator, Author):

So I did this because I couldn't figure out how to get the batch dimension as the dimension of the list! It seems like vmap only broadcasts the batch dimension into the arrays, but maybe I fed things in wrong. This is the only thing I tried that gave the correct shape for now.

phinate (Collaborator, Author):

this problem does not exist when flatten=True, since it's arrays all the way down!

def logsignature(
-    path: Float[Array, "path_len dim"], depth: int, stream: bool = False
+    path: Float[Array, "path_len dim"],
anh-tong (Owner):

I guess you want to make logsignature handle input of shape (batch_size, path_len, dim) as well.

In this case, we also need to handle the shape, and we need an extra vmap.

phinate (Collaborator, Author):

this should work automatically by dispatching to those checks in signature, which should do the right thing, I think! But let me know if there's logic I'm missing here.

Comment on lines 24 to 25
path = rng.standard_normal(length)
signature(path, depth)
anh-tong (Owner):

We can remove this. It actually causes an error since path is 1D.

phinate (Collaborator, Author):

done!

@anh-tong (Owner):

Overall, it looks good. Once the behaviour of flatten in the _signature function is corrected and logsignature can handle inputs of shape (batch_size, path_len, dim), we can merge this into main.

Also, can you please check why the CI format check has errors?

@phinate (Collaborator, Author) commented Jul 16, 2023

ok, things look mostly okay now, i think i've caught the majority of errors -- there are still a couple of tests failing with stream=True -- do you think i'm handling the flattening logic wrong somewhere?

@anh-tong (Owner) commented Jul 17, 2023

From the comment

# if not flattening, use list comprehension since vmap cant handle list outputs

I'm curious about why we cannot vmap a function returning a list.

I tested with a simple example like this:

import jax
import jax.numpy as jnp
from jax import flatten_util

def fn(x, flatten=False):
    n = x.shape[0]
    res = [jnp.ones((n, )), jnp.ones((n, n))]
    if flatten:
        res = flatten_util.ravel_pytree(res)[0]
    return res

# normal case
x = jnp.ones((2,))
output = fn(x)
[print(o.shape) for o in output]
# returned shape: [(2,), (2, 2)]

# vmap
print("--- vmap ---")
batch_size = 5
x = jnp.ones((batch_size, 2))
output = jax.vmap(fn)(x)
[print(o.shape) for o in output]
# return [(batch_size, 2), (batch_size, 2, 2)]


# flatten + vmap
print("--- flatten + vmap ---")
batch_size = 5
x = jnp.ones((batch_size, 2))
output = jax.vmap(lambda _x: fn(_x, flatten=True))(x)
print(output.shape)
# return (batch_size, 2 + 2 * 2)

We can make _signature handle all the arguments (stream, flatten, ...). The remaining job of the signature function is then just to check the shape and dispatch between _signature and a vmapped _signature.

I suggest handling flatten for the cases of stream=True and stream=False at this line:

if stream:
    res = [
        jnp.concatenate([first[None, ...], rest], axis=0)
        for first, rest in zip(exp_term, stacked)
    ]
else:
    res = carry
if flatten:
    res = flatten_util.ravel_pytree(res)[0]

with something like this

if stream:
    res = [
        jnp.concatenate([first[None, ...], rest], axis=0)
        for first, rest in zip(exp_term, stacked)
    ]
    # here `res` has shape [(path_len - 1, dim), (path_len - 1, dim, dim), ...]
    if flatten:
        res = jax.vmap(lambda _x: flatten_util.ravel_pytree(_x)[0])(res)
        # now `res` has shape (path_len - 1, dim + dim * dim + ...)
else:
    res = carry
    # `res` has shape [(dim,), (dim, dim), ...]
    if flatten:
        res = flatten_util.ravel_pytree(res)[0]  # `res` has shape (dim + dim * dim + ..., )

I believe the following line works well. For consistency, we could probably change signature to _signature here, since _signature is designated to handle 2D input.

remainder_signature = signature(path_remainder, depth, stream)

@phinate (Collaborator, Author) commented Jul 19, 2023

I'm curious about why we cannot vmap a function returning a list.
I test with a simple example like this:

import jax
import jax.numpy as jnp
from jax import flatten_util

def fn(x, flatten=False):
    n = x.shape[0]
    res = [jnp.ones((n, )), jnp.ones((n, n))]
    if flatten:
        res = flatten_util.ravel_pytree(res)[0]
    return res

# normal case
x = jnp.ones((2,))
output = fn(x)
[print(o.shape) for o in output]
# returned shape: [(2,), (2, 2)]

# vmap
print("--- vmap ---")
batch_size = 5
x = jnp.ones((batch_size, 2))
output = jax.vmap(fn)(x)
[print(o.shape) for o in output]
# return [(batch_size, 2), (batch_size, 2, 2)]


# flatten + vmap
print("--- flatten + vmap ---")
batch_size = 5
x = jnp.ones((batch_size, 2))
output = jax.vmap(lambda _x: fn(_x, flatten=True))(x)
print(output.shape)
# return (batch_size, 2 + 2 * 2)

It's the second case that I implemented this line for -- wouldn't you expect [(2,), (2, 2)] * batch_size instead of [(batch_size, 2), (batch_size, 2, 2)] here? That's the hard thing to do with vmap.
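To make the point concrete: vmap inserts the batch axis into each leaf of the output pytree, so a per-example list of lists has to be rebuilt afterwards by indexing the stacked leaves (a minimal sketch):

```python
import jax
import jax.numpy as jnp

def fn(x):
    # returns a list of arrays, as in the signature levels above
    n = x.shape[0]
    return [jnp.ones((n,)), jnp.ones((n, n))]

batch = jnp.ones((5, 2))
stacked = jax.vmap(fn)(batch)
# vmap puts the batch axis inside each leaf, not around the list:
print([o.shape for o in stacked])  # [(5, 2), (5, 2, 2)]

# rebuilding [(2,), (2, 2)] per batch element requires explicit indexing
per_example = [[leaf[i] for leaf in stacked] for i in range(batch.shape[0])]
print([o.shape for o in per_example[0]])  # [(2,), (2, 2)]
```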

I suggest handling flatten for the cases of stream=True and stream=False at this line:

if stream:
    res = [
        jnp.concatenate([first[None, ...], rest], axis=0)
        for first, rest in zip(exp_term, stacked)
    ]
else:
    res = carry
if flatten:
    res = flatten_util.ravel_pytree(res)[0]

with something like this

if stream:
    res = [
        jnp.concatenate([first[None, ...], rest], axis=0)
        for first, rest in zip(exp_term, stacked)
    ]
    # here `res` has shape [(path_len - 1, dim), (path_len - 1, dim, dim), ...]
    if flatten:
        res = jax.vmap(lambda _x: flatten_util.ravel_pytree(_x)[0])(res)
        # now `res` has shape (path_len - 1, dim + dim * dim + ...)
else:
    res = carry
    # `res` has shape [(dim,), (dim, dim), ...]
    if flatten:
        res = flatten_util.ravel_pytree(res)[0]  # `res` has shape (dim + dim * dim + ..., )

cool, I'll look at this now -- thanks for your help @anh-tong :D

@codecov bot commented Jul 21, 2023

Codecov Report

Merging #35 (dd64894) into main (b867de3) will increase coverage by 1.08%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main      #35      +/-   ##
==========================================
+ Coverage   92.85%   93.93%   +1.08%     
==========================================
  Files           7        7              
  Lines         224      264      +40     
==========================================
+ Hits          208      248      +40     
  Misses         16       16              
Files Changed              Coverage Δ
src/signax/__init__.py     100.00% <ø> (ø)
src/signax/utils.py         75.92% <ø> (-1.67%) ⬇️
src/signax/module.py       100.00% <100.00%> (ø)
src/signax/signatures.py   100.00% <100.00%> (ø)


@phinate (Collaborator, Author) commented Jul 21, 2023

stream fix worked! but the code's a bit messy now -- I'd need to think about how to handle things for logsignature (I currently force no flattening, since signature_to_logsignature needs things in list form), but I could probably rewrite the flatten logic to happen once, in the same way for all functions, instead of pasting the same code everywhere...
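One way that "happen once" idea might look is a small wrapper around the list-valued functions. This is only a sketch: `with_flatten` and `dummy_sig` are hypothetical names, and `dummy_sig` just fabricates the list-of-levels structure discussed above rather than computing a real signature.

```python
import jax
import jax.numpy as jnp
from jax import flatten_util

def with_flatten(fn):
    # hypothetical wrapper: apply the flatten branch in one place,
    # instead of repeating it inside every function
    def wrapped(path, *args, stream=False, flatten=True, **kwargs):
        res = fn(path, *args, stream=stream, **kwargs)
        if not flatten:
            return res  # keep the list form, e.g. for signature_to_logsignature
        if stream:
            # leading axis is path_len - 1: flatten each step separately
            return jax.vmap(lambda r: flatten_util.ravel_pytree(r)[0])(res)
        return flatten_util.ravel_pytree(res)[0]
    return wrapped

@with_flatten
def dummy_sig(path, stream=False):
    # stand-in returning the list-of-levels structure of a depth-2 signature
    d = path.shape[-1]
    if stream:
        n = path.shape[0] - 1
        return [jnp.ones((n, d)), jnp.ones((n, d, d))]
    return [jnp.ones((d,)), jnp.ones((d, d))]

print(dummy_sig(jnp.ones((4, 3))).shape)               # (12,)
print(dummy_sig(jnp.ones((4, 3)), stream=True).shape)  # (3, 12)
```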

@anh-tong (Owner) commented Jul 25, 2023

Hi @phinate, sorry for taking so long. I have updated the code as follows:

  • the functions _signature and _logsignature just handle input with shape (length, channels)
  • the functions signature and logsignature, built on top of _signature and _logsignature, just perform the shape handling
  • flatten=True is now the default option, like signatory

Indeed, keeping track of the shape can be challenging, especially with the stream and flatten options being turned on or off.

Let me know if the behavior of the current version is what you expect.

@anh-tong (Owner):

Let's merge this. Thanks for the PR.

@anh-tong anh-tong merged commit 02e8591 into main Jul 28, 2023
7 checks passed
@phinate phinate deleted the shape-handling branch August 13, 2023 12:24