Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CI and log signature test in word mode #57

Merged
merged 3 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,9 @@ jobs:
python -m pip install --upgrade setuptools pip wheel
python -m pip install .[test]

- name: Install heavy deps for testing (signatory + torch)
- name: Install iisignature for testing
run: |
python -m pip install torch==1.9.0
python -m pip install signatory==1.2.6.1.9.0 --no-cache-dir --force-reinstall
python -m pip install iisignature

- name: Test package
run: >-
Expand Down
79 changes: 75 additions & 4 deletions tests/test_against_iisignature.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import pytest
from numpy.random import default_rng

from signax import logsignature, signature
from signax import logsignature, multi_signature_combine, signature
from signax.utils import compress, lyndon_words, unravel_signature

rng = default_rng()

Expand All @@ -32,14 +33,84 @@ def test_signature(depth, length, dim, stream):
assert jnp.allclose(jax_signature, iis_signature)


def test_multi_signature_combine():
# iisignature does not support multiple signature combination
# we only test for the case of combining two signatures
# note: this test is passed in signatory before migrating to iisignature
n_signatures = 2
dim = 5
signatures = [
rng.standard_normal((n_signatures, dim)),
rng.standard_normal((n_signatures, dim, dim)),
rng.standard_normal((n_signatures, dim, dim, dim)),
]

jax_signatures = [jnp.array(x) for x in signatures]

jax_output = multi_signature_combine(jax_signatures)
jax_output = jnp.concatenate([jnp.ravel(x) for x in jax_output])

iis_signatures = []
for i in range(n_signatures):
tensors = [np.asarray(x[i]) for x in signatures]
current = np.concatenate([t.flatten() for t in tensors])
current = current[None, :]
iis_signatures.append(current)

iis_output = iisignature.sigcombine(
iis_signatures[0],
iis_signatures[1],
dim,
len(signatures),
)
iis_output = jnp.array(iis_output)
assert jnp.allclose(jax_output, iis_output)


@pytest.mark.parametrize("stream", [True, False])
def test_signature_batch(stream):
depth = 3

# no remainder case
length = 1001
dim = 5
n_chunks = 10

path = rng.standard_normal((length, dim))
jax_signature = signature(
path, depth=depth, num_chunks=n_chunks, stream=stream, flatten=True
)

iis_signature = (
iisignature.sig(np.asarray(path), depth)
if not stream
else iisignature.sig(np.asarray(path), depth, 2)
)
iis_signature = jnp.asarray(iis_signature)

assert jnp.allclose(jax_signature, iis_signature)


@pytest.mark.parametrize(
("depth", "length", "dim", "stream"), [(1, 2, 2, False), (3, 3, 5, False)]
)
def test_logsignature(depth, length, dim, stream):
batch_size = 10
path = rng.standard_normal((batch_size, length, dim))
jax_logsignature = logsignature(path, depth=depth, stream=stream, flatten=True)
s = iisignature.prepare(dim, depth, "O")
iis_logsignature = iisignature.logsig(np.asarray(path), s)
iis_logsignature = jnp.asarray(iis_logsignature)

# get expanded version of log signature
s = iisignature.prepare(dim, depth, "x")
iis_logsignature = iisignature.logsig(np.asarray(path), s, "x")

def _compress(expanded_log_signature):
# convert expanded array as list of arrays
expanded_log_signature = unravel_signature(expanded_log_signature, dim, depth)
indices = lyndon_words(depth, dim)
compressed = compress(expanded_log_signature, indices)
compressed = jnp.concatenate(compressed)
return compressed

iis_logsignature = jax.vmap(_compress)(iis_logsignature)

assert jnp.allclose(jax_logsignature, iis_logsignature, atol=5e-1, rtol=5e-1)
49 changes: 26 additions & 23 deletions tests/test_tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import numpy as np
from numpy.random import default_rng

from signax import signature, signature_to_logsignature
from signax.tensor_ops import (
addcmul,
mult,
mult_fused_restricted_exp,
otimes,
restricted_exp,
)
from signax.utils import compress, lyndon_words, unravel_signature

rng = default_rng()

Expand Down Expand Up @@ -110,26 +112,27 @@ def test_mult():
assert jnp.allclose(iis_signature, jax_output)


# def test_log():
# """Test log via signature_to_logsignature"""
# depth = 4
# length, dim = 3, 2
# path = rng.standard_normal((length, dim))
# jax_path = jnp.array(path)
# jax_signature = signature(jax_path, depth, flatten=False)
# jax_logsignature = signature_to_logsignature(jax_signature)
# jax_output = jnp.concatenate([jnp.ravel(x) for x in jax_logsignature])

# torch_signature = signatory.signature(
# torch.tensor(path)[None, ...],
# depth,
# )
# torch_logsignature = signatory.signature_to_logsignature(
# torch_signature,
# dim,
# depth,
# )

# torch_output = jnp.array(torch_logsignature.numpy())

# assert jnp.allclose(torch_output, jax_output)
def test_log():
"""Test log via signature_to_logsignature"""
depth = 4
length, dim = 3, 2
path = rng.standard_normal((length, dim))
jax_path = jnp.array(path)
jax_signature = signature(jax_path, depth, flatten=False)
jax_logsignature = signature_to_logsignature(jax_signature)
jax_output = jnp.concatenate([jnp.ravel(x) for x in jax_logsignature])

s = iisignature.prepare(dim, depth, "x")
iis_logsignature = iisignature.logsig(np.asarray(path), s, "x")

def _compress(expanded_log_signature):
# convert expanded array as list of arrays
expanded_log_signature = unravel_signature(expanded_log_signature, dim, depth)
indices = lyndon_words(depth, dim)
compressed = compress(expanded_log_signature, indices)
compressed = jnp.concatenate(compressed)
return compressed

iis_logsignature = _compress(iis_logsignature)

assert jnp.allclose(iis_logsignature, jax_output)