Skip to content

Commit

Permalink
fix test and CI
Browse files Browse the repository at this point in the history
  • Loading branch information
anh-tong committed Mar 27, 2024
1 parent 539e160 commit 7588949
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 31 deletions.
5 changes: 0 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,6 @@ jobs:
python -m pip install --upgrade setuptools pip wheel
python -m pip install .[test]
- name: Install heavy deps for testing (signatory + torch)
run: |
python -m pip install torch==1.9.0
python -m pip install signatory==1.2.6.1.9.0 --no-cache-dir --force-reinstall
- name: Test package
run: >-
python -m pytest -ra --cov --cov-report=xml --cov-report=term
Expand Down
19 changes: 16 additions & 3 deletions tests/test_against_iisignature.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from numpy.random import default_rng

from signax import logsignature, signature
from signax.utils import compress, lyndon_words, unravel_signature

rng = default_rng()

Expand Down Expand Up @@ -39,7 +40,19 @@ def test_logsignature(depth, length, dim, stream):
batch_size = 10
path = rng.standard_normal((batch_size, length, dim))
jax_logsignature = logsignature(path, depth=depth, stream=stream, flatten=True)
s = iisignature.prepare(dim, depth, "O")
iis_logsignature = iisignature.logsig(np.asarray(path), s)
iis_logsignature = jnp.asarray(iis_logsignature)

# get expanded version of log signature
s = iisignature.prepare(dim, depth, "x")
iis_logsignature = iisignature.logsig(np.asarray(path), s, "x")

def _compress(expanded_log_signature):
# convert expanded array as list of arrays
expanded_log_signature = unravel_signature(expanded_log_signature, dim, depth)
indices = lyndon_words(depth, dim)
compressed = compress(expanded_log_signature, indices)
compressed = jnp.concatenate(compressed)
return compressed

iis_logsignature = jax.vmap(_compress)(iis_logsignature)

assert jnp.allclose(jax_logsignature, iis_logsignature, atol=5e-1, rtol=5e-1)
49 changes: 26 additions & 23 deletions tests/test_tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import numpy as np
from numpy.random import default_rng

from signax import signature, signature_to_logsignature
from signax.tensor_ops import (
addcmul,
mult,
mult_fused_restricted_exp,
otimes,
restricted_exp,
)
from signax.utils import compress, lyndon_words, unravel_signature

rng = default_rng()

Expand Down Expand Up @@ -110,26 +112,27 @@ def test_mult():
assert jnp.allclose(iis_signature, jax_output)


# def test_log():
# """Test log via signature_to_logsignature"""
# depth = 4
# length, dim = 3, 2
# path = rng.standard_normal((length, dim))
# jax_path = jnp.array(path)
# jax_signature = signature(jax_path, depth, flatten=False)
# jax_logsignature = signature_to_logsignature(jax_signature)
# jax_output = jnp.concatenate([jnp.ravel(x) for x in jax_logsignature])

# torch_signature = signatory.signature(
# torch.tensor(path)[None, ...],
# depth,
# )
# torch_logsignature = signatory.signature_to_logsignature(
# torch_signature,
# dim,
# depth,
# )

# torch_output = jnp.array(torch_logsignature.numpy())

# assert jnp.allclose(torch_output, jax_output)
def test_log():
"""Test log via signature_to_logsignature"""
depth = 4
length, dim = 3, 2
path = rng.standard_normal((length, dim))
jax_path = jnp.array(path)
jax_signature = signature(jax_path, depth, flatten=False)
jax_logsignature = signature_to_logsignature(jax_signature)
jax_output = jnp.concatenate([jnp.ravel(x) for x in jax_logsignature])

s = iisignature.prepare(dim, depth, "x")
iis_logsignature = iisignature.logsig(np.asarray(path), s, "x")

def _compress(expanded_log_signature):
# convert expanded array as list of arrays
expanded_log_signature = unravel_signature(expanded_log_signature, dim, depth)
indices = lyndon_words(depth, dim)
compressed = compress(expanded_log_signature, indices)
compressed = jnp.concatenate(compressed)
return compressed

iis_logsignature = _compress(iis_logsignature)

assert jnp.allclose(iis_logsignature, jax_output)

0 comments on commit 7588949

Please sign in to comment.