Skip to content

Commit

Permalink
add missing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
anh-tong committed Mar 27, 2024
1 parent 5958ee4 commit 28db27d
Showing 1 changed file with 59 additions and 1 deletion.
60 changes: 59 additions & 1 deletion tests/test_against_iisignature.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from numpy.random import default_rng

from signax import logsignature, signature
from signax import logsignature, multi_signature_combine, signature
from signax.utils import compress, lyndon_words, unravel_signature

rng = default_rng()
Expand All @@ -33,6 +33,64 @@ def test_signature(depth, length, dim, stream):
assert jnp.allclose(jax_signature, iis_signature)


def test_multi_signature_combine():
# iisignature does not support multiple signature combination
# we only test for the case of combining two signatures
# note: this test is passed in signatory before migrating to iisignature
n_signatures = 2
dim = 5
signatures = [
rng.standard_normal((n_signatures, dim)),
rng.standard_normal((n_signatures, dim, dim)),
rng.standard_normal((n_signatures, dim, dim, dim)),
]

jax_signatures = [jnp.array(x) for x in signatures]

jax_output = multi_signature_combine(jax_signatures)
jax_output = jnp.concatenate([jnp.ravel(x) for x in jax_output])

iis_signatures = []
for i in range(n_signatures):
tensors = [np.asarray(x[i]) for x in signatures]
current = np.concatenate([t.flatten() for t in tensors])
current = current[None, :]
iis_signatures.append(current)

iis_output = iisignature.sigcombine(
iis_signatures[0],
iis_signatures[1],
dim,
len(signatures),
)
iis_output = jnp.array(iis_output)
assert jnp.allclose(jax_output, iis_output)


@pytest.mark.parametrize("stream", [True, False])
def test_signature_batch(stream):
depth = 3

# no remainder case
length = 1001
dim = 5
n_chunks = 10

path = rng.standard_normal((length, dim))
jax_signature = signature(
path, depth=depth, num_chunks=n_chunks, stream=stream, flatten=True
)

iis_signature = (
iisignature.sig(np.asarray(path), depth)
if not stream
else iisignature.sig(np.asarray(path), depth, 2)
)
iis_signature = jnp.asarray(iis_signature)

assert jnp.allclose(jax_signature, iis_signature)


@pytest.mark.parametrize(
("depth", "length", "dim", "stream"), [(1, 2, 2, False), (3, 3, 5, False)]
)
Expand Down

0 comments on commit 28db27d

Please sign in to comment.