diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e011bb9..780052a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -57,11 +57,6 @@ jobs: python -m pip install --upgrade setuptools pip wheel python -m pip install .[test] - - name: Install heavy deps for testing (signatory + torch) - run: | - python -m pip install torch==1.9.0 - python -m pip install signatory==1.2.6.1.9.0 --no-cache-dir --force-reinstall - - name: Test package run: >- python -m pytest -ra --cov --cov-report=xml --cov-report=term diff --git a/tests/test_against_iisignature.py b/tests/test_against_iisignature.py index 8175610..608f72e 100644 --- a/tests/test_against_iisignature.py +++ b/tests/test_against_iisignature.py @@ -8,6 +8,7 @@ from numpy.random import default_rng from signax import logsignature, signature +from signax.utils import compress, lyndon_words, unravel_signature rng = default_rng() @@ -39,7 +40,19 @@ def test_logsignature(depth, length, dim, stream): batch_size = 10 path = rng.standard_normal((batch_size, length, dim)) jax_logsignature = logsignature(path, depth=depth, stream=stream, flatten=True) - s = iisignature.prepare(dim, depth, "O") - iis_logsignature = iisignature.logsig(np.asarray(path), s) - iis_logsignature = jnp.asarray(iis_logsignature) + + # get expanded version of log signature + s = iisignature.prepare(dim, depth, "x") + iis_logsignature = iisignature.logsig(np.asarray(path), s, "x") + + def _compress(expanded_log_signature): + # convert expanded array as list of arrays + expanded_log_signature = unravel_signature(expanded_log_signature, dim, depth) + indices = lyndon_words(depth, dim) + compressed = compress(expanded_log_signature, indices) + compressed = jnp.concatenate(compressed) + return compressed + + iis_logsignature = jax.vmap(_compress)(iis_logsignature) + assert jnp.allclose(jax_logsignature, iis_logsignature, atol=5e-1, rtol=5e-1) diff --git a/tests/test_tensor_ops.py b/tests/test_tensor_ops.py index 5ab6458..3f246b3 100644 --- a/tests/test_tensor_ops.py +++ b/tests/test_tensor_ops.py @@ -6,6 +6,7 @@ import numpy as np from numpy.random import default_rng +from signax import signature, signature_to_logsignature from signax.tensor_ops import ( addcmul, mult, @@ -13,6 +14,7 @@ otimes, restricted_exp, ) +from signax.utils import compress, lyndon_words, unravel_signature rng = default_rng() @@ -110,26 +112,27 @@ def test_mult(): assert jnp.allclose(iis_signature, jax_output) -# def test_log(): -# """Test log via signature_to_logsignature""" -# depth = 4 -# length, dim = 3, 2 -# path = rng.standard_normal((length, dim)) -# jax_path = jnp.array(path) -# jax_signature = signature(jax_path, depth, flatten=False) -# jax_logsignature = signature_to_logsignature(jax_signature) -# jax_output = jnp.concatenate([jnp.ravel(x) for x in jax_logsignature]) - -# torch_signature = signatory.signature( -# torch.tensor(path)[None, ...], -# depth, -# ) -# torch_logsignature = signatory.signature_to_logsignature( -# torch_signature, -# dim, -# depth, -# ) - -# torch_output = jnp.array(torch_logsignature.numpy()) - -# assert jnp.allclose(torch_output, jax_output) +def test_log(): + """Test log via signature_to_logsignature""" + depth = 4 + length, dim = 3, 2 + path = rng.standard_normal((length, dim)) + jax_path = jnp.array(path) + jax_signature = signature(jax_path, depth, flatten=False) + jax_logsignature = signature_to_logsignature(jax_signature) + jax_output = jnp.concatenate([jnp.ravel(x) for x in jax_logsignature]) + + s = iisignature.prepare(dim, depth, "x") + iis_logsignature = iisignature.logsig(np.asarray(path), s, "x") + + def _compress(expanded_log_signature): + # convert expanded array as list of arrays + expanded_log_signature = unravel_signature(expanded_log_signature, dim, depth) + indices = lyndon_words(depth, dim) + compressed = compress(expanded_log_signature, indices) + compressed = jnp.concatenate(compressed) + return compressed + + iis_logsignature = _compress(iis_logsignature) + + assert jnp.allclose(iis_logsignature, jax_output)