Skip to content

Commit

Permalink
Merge pull request #35 from anh-tong/shape-handling
Browse files Browse the repository at this point in the history
Move signature implementation to a single function, and catch shape errors
  • Loading branch information
anh-tong authored Jul 28, 2023
2 parents b867de3 + dd64894 commit 02e8591
Show file tree
Hide file tree
Showing 10 changed files with 392 additions and 190 deletions.
35 changes: 17 additions & 18 deletions examples/inversion.ipynb

Large diffs are not rendered by default.

17 changes: 8 additions & 9 deletions examples/nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
import jax
import jax.numpy as jnp
import jax.random as jrandom
from jax import flatten_util

from signax import signature, signature_combine
from signax.module import SignatureTransform
from signax.utils import flatten


def flatten(x):
return flatten_util.ravel_pytree(x)[0]


def _make_convs(input_size: int, layer_sizes, kernel_size, *, key):
Expand Down Expand Up @@ -144,12 +148,7 @@ def _f(carry, i):

# output is a tensor algebra which is a list of `jnp.ndarray`
# size of output: [(n, dim), (n, dim, dim,), (n, dim, dim, dim), ...]

def _signature(x):
ta = signature(x, self.signature_depth)
return flatten(ta)

output = jax.vmap(_signature)(output)
output = jax.vmap(signature)(output)

return output

Expand Down Expand Up @@ -197,7 +196,7 @@ def __call__(self, x, *, key=None):
)

# signature of the first window
init = signature(x[: self.length], self.signature_depth)
init = signature(x[: self.length], self.signature_depth, flatten=False)

def f(carry, i):
"""
Expand All @@ -210,7 +209,7 @@ def f(carry, i):
start_indices=(i - 1, 0),
slice_sizes=(self.adjusted_length + 1, dim),
)
sig = signature(current_x, self.signature_depth)
sig = signature(current_x, self.signature_depth, flatten=False)
out = signature_combine(carry, sig)

# carry the current signature to the next iteration
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -152,5 +152,6 @@ isort.required-imports = ["from __future__ import annotations"]
[tool.ruff.per-file-ignores]
"tests/**" = ["T20"]
"noxfile.py" = ["T20"]
"src/signax/signatures.py" = ["ARG001", "ARG005"] # unused arguments
"src/signax/module.py" = ["ARG002"] # unused argument key in __call__ of Module (equinox)
"examples/nets.py" = ["ARG002"] # unused argument key in __call__ of Module (equinox)
2 changes: 0 additions & 2 deletions src/signax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@
"signature_combine",
"signature_to_logsignature",
"multi_signature_combine",
"signature_batch",
)

from signax import module, tensor_ops, utils
from signax.signatures import (
logsignature,
multi_signature_combine,
signature,
signature_batch,
signature_combine,
signature_to_logsignature,
)
11 changes: 8 additions & 3 deletions src/signax/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from jaxtyping import Array, Float

from signax.signatures import logsignature, signature
from signax.utils import flatten


class SignatureTransform(eqx.Module):
depth: int
stream: bool
num_chunks: int = 1

def __init__(self, depth: int, stream: bool = False) -> None:
self.depth = depth
Expand All @@ -25,12 +25,15 @@ def __call__(
*,
key: Any | None = None,
) -> Array:
return flatten(signature(path, self.depth, self.stream))
return signature(
path, self.depth, self.stream, flatten=True, num_chunks=self.num_chunks
)


class LogSignatureTransform(eqx.Module):
depth: int
stream: bool
num_chunks: int = 1

def __init__(self, depth: int, stream: bool = False) -> None:
self.depth = depth
Expand All @@ -42,4 +45,6 @@ def __call__(
*,
key: Any | None = None,
) -> Array:
return flatten(logsignature(path, self.depth, self.stream))
return logsignature(
path, self.depth, self.stream, flatten=True, num_chunks=self.num_chunks
)
Loading

0 comments on commit 02e8591

Please sign in to comment.