
eqx.error_if errors even when predicate is false if input array is empty #835

Open
allen-adastra opened this issue Sep 6, 2024 · 5 comments
Labels
question User queries

Comments

@allen-adastra

Hello!

The following currently errors out:

eqx.error_if((), jnp.array(False), "Errors?")

with the error:

.venv/lib/python3.11/site-packages/equinox/_errors.py:229: in error_if
    return branched_error_if(x, pred, 0, [msg], on_error=on_error)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

x = (), pred = Array(False, dtype=bool), index = 0, msgs = ['Errors?']

    @doc_remove_args("on_error")
    def branched_error_if(
        x: PyTree,
        pred: Bool[ArrayLike, "..."],
        index: Int[ArrayLike, "..."],
        msgs: Sequence[str],
        *,
        on_error: Literal["default", "raise", "breakpoint", "nan"] = "default",
    ) -> PyTree:
        """As [`equinox.error_if`][], but will raise one of
        several `msgs` depending on the value of `index`. If `index` is vmap'd, then the
        error message from the largest value (across the whole batch) will be used.
        """
        leaves = jtu.tree_leaves((x, pred, index))
        # This carefully does not perform any JAX operations if `pred` and `index` are
        # a bool and an int.
        # This ensures we can use `error_if` before init_google.
        if any(is_array(leaf) for leaf in leaves):
>           return branched_error_if_impl_jit(x, pred, index, msgs, on_error=on_error)
E           ValueError: No arrays to thread error on to.
E           --------------------
E           For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

This is the line of code where the error happens:

    if len(flat) == 0:
        raise ValueError("No arrays to thread error on to.")

I'm wondering if there's some way around this in the case that pred=False?

@allen-adastra
Author

Another issue: pure_callback seems unable to handle arrays that have one of JAX's custom PRNG key dtypes.

The following fixes things for my purposes:

def wrapped_error_if(x, pred, msg):
    # The error_if function is not happy if we input prng_key types.
    x = eqx.filter(x, lambda leaf: not jnp.issubdtype(leaf.dtype, jax.dtypes.prng_key))

    # Provide a sentinel x.
    # Work-around until https://github.com/patrick-kidger/equinox/issues/835 is fixed.
    flat = jax.tree.leaves(x)
    if len(flat) == 0:
        sentinel_x = jnp.array(True)
        eqx.error_if(sentinel_x, pred, msg)
    else:
        eqx.error_if(x, pred, msg)

@allen-adastra
Copy link
Author

Whoops, needed a small addition :)

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array

def patched_error_if(x, pred, msg):
    # The error_if function is not happy if we input prng_key types.
    # This is due to its calling of jax.pure_callback, which tries to use a pure
    # NumPy check on the array, which has a custom JAX dtype.
    x = eqx.filter(x, lambda leaf: isinstance(leaf, Array) and not jnp.issubdtype(leaf.dtype, jax.dtypes.prng_key))

    # Provide a sentinel x.
    # Work-around until https://github.com/patrick-kidger/equinox/issues/835 is resolved.
    flat = jax.tree.leaves(x)
    if len(flat) == 0:
        sentinel_x = jnp.array(True)
        eqx.error_if(sentinel_x, pred, msg)
    else:
        eqx.error_if(x, pred, msg)

@patrick-kidger
Owner

This is intentional. The array is needed to place the error check in the right place in the program. I think special-casing static Falses here may lead to hard-to-understand behaviour. But if you want it, then I imagine you should be able to accomplish this via something like:

import equinox as eqx
import jax

def allen_error_if(x, pred, msg):
    if type(pred) is not bool and type(jax.core.get_aval(pred)) is jax.core.ConcreteArray:
        with jax.ensure_compile_time_eval():
            pred = pred.item()
    if pred is False:
        return x
    else:
        return eqx.error_if(x, pred, msg)

Note that your code above has a bug: it does not return the result of eqx.error_if, so nothing depends on the error check and it will be removed by dead-code elimination (DCE).

@patrick-kidger added the question (User queries) label on Sep 7, 2024
@allen-adastra
Author


Hm, it seems like it would be generally helpful for error_if to handle empty PyTrees (or for there to be a canonical workaround). What would you recommend?

@allen-adastra
Author


Revisiting this, I think there was a misunderstanding. The issue is not pred; it is x being empty. For example, none of the following work:

eqx.error_if((), jnp.arange(10), "Errors?")
eqx.error_if(None, jnp.arange(10), "Errors?")
eqx.error_if([], jnp.arange(10), "Errors?")
