Allow two variadic shapes when it makes sense #184

Open
kho opened this issue Mar 5, 2024 · 1 comment
Labels
feature New feature

Comments

kho commented Mar 5, 2024

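Currently, if I write a function like the following:

```python
import jax.numpy as jnp
from jaxtyping import Array, Bool, Shaped

def mask_invalid(
    x: Shaped[Array, '*B *C'], mask: Bool[Array, '*B']
) -> Shaped[Array, '*B *C']:
  # Append trailing singleton axes so mask broadcasts against x.
  return jnp.where(jnp.expand_dims(mask, range(mask.ndim, x.ndim)), x, 0)
```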

I get `ValueError: Cannot use variadic specifiers (*name or ...) more than once.` That seems overly restrictive. `mask` has only one variadic specifier, so its shape determines `*B`. Once `*B` is known, `*C` is simply the rest of `x`'s shape. Could the requirement be relaxed to something like "each shape may contain at most one variadic specifier that cannot be determined from the other shapes", i.e. the following algorithm?

```python
from collections.abc import Sequence


# Each shape is a list of (dimension name, is_variadic) pairs.
# Returns the names of the variadic dimensions that cannot be determined.
def eliminate_determinable_variadic_shapes(
    *shapes: Sequence[tuple[str, bool]]
) -> set[str]:
  remaining = set(range(len(shapes)))
  # Variadic dimension names that can be determined.
  determinable = set()
  while True:
    new_remaining = set()
    for i in remaining:
      variadics = [
          name
          for name, variadic in shapes[i]
          if variadic and name not in determinable
      ]
      if len(variadics) > 1:
        # Still more than one unknown variadic: retry on the next pass.
        new_remaining.add(i)
      elif len(variadics) == 1:
        # A single unknown variadic is pinned down by this shape's rank.
        determinable.add(variadics[0])
        print(variadics[0], 'becomes determinable because of', shapes[i])
    # Stop once a full pass makes no progress.
    if len(remaining) == len(new_remaining):
      return {
          name
          for shape in shapes
          for name, variadic in shape
          if variadic and name not in determinable
      }
    remaining = new_remaining
```
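For `mask_invalid` specifically, the inference is short enough to spell out. The sketch below is hypothetical: `infer_mask_invalid_dims` is a made-up name, not part of jaxtyping, and it works on plain shape tuples rather than arrays. `mask`'s single variadic fixes `*B`, and `*C` is whatever remains of `x`'s shape after that prefix.

```python
def infer_mask_invalid_dims(
    x_shape: tuple[int, ...], mask_shape: tuple[int, ...]
) -> dict[str, tuple[int, ...]]:
  """Binds *B and *C for x: '*B *C' and mask: '*B' (illustrative only)."""
  # mask's only dimension spec is the variadic *B, so its shape is exactly *B.
  b = tuple(mask_shape)
  # x must start with *B; whatever follows that prefix is *C.
  if tuple(x_shape[:len(b)]) != b:
    raise ValueError(f'x shape {tuple(x_shape)} does not start with *B={b}')
  return {'B': b, 'C': tuple(x_shape[len(b):])}


print(infer_mask_invalid_dims((4, 5, 2, 3), (4, 5)))
# {'B': (4, 5), 'C': (2, 3)}
```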

@patrick-kidger
Owner

IIUC, you're basically trying to resolve the variadic shapes one at a time?
That would indeed be possible in principle, but it might be tricky to implement. Right now we leave this checking to the runtime type checker, and it is the type checker that determines the order in which arguments are checked.

We wouldn't have to do it that way. We could instead use the runtime type checker only to traverse the annotations (i.e. to handle nested annotations like `tuple[dict[str, Float[Array, ...]]]`), record every annotation it comes across as one we want to check against, and then perform the shape checks ourselves using an algorithm like the one you describe.

In practice, I'm afraid that would be a fair amount of work for quite a niche feature, and it might end up being fairly fragile (I think it would involve using a dynamic context).
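As a rough illustration of the "record, then check" approach, here is a minimal sketch. It makes several assumptions. Annotations are taken to be already parsed into `(name, is_variadic)` pairs, as in the algorithm above. Only named dimensions are supported (no literal sizes, `...` or `_`). The traversal itself is left to whatever calls `record_shape`. Both `deferred_shape_checks` and `record_shape` are hypothetical names, not jaxtyping API. A `contextvars.ContextVar` plays the role of the dynamic context. When the context exits, pending checks are retried until a full pass makes no progress, so the result does not depend on the order in which arguments were visited.

```python
import contextlib
import contextvars
from collections.abc import Iterator, Sequence

# Hypothetical representation: a spec is a list of (dim name, is_variadic)
# pairs, e.g. '*B *C' -> [('B', True), ('C', True)]. Named dims only.
Spec = Sequence[tuple[str, bool]]
Bindings = dict[str, tuple[int, ...]]

# The "dynamic context": (spec, shape) pairs recorded during traversal.
_pending = contextvars.ContextVar('pending_shape_checks')


def record_shape(spec: Spec, shape: tuple[int, ...]) -> None:
  """Defers a check; must be called inside deferred_shape_checks()."""
  _pending.get().append((spec, tuple(shape)))


def _try_check(spec: Spec, shape: tuple[int, ...], bindings: Bindings) -> bool:
  """Checks one shape, binding new names; returns False if not yet decidable."""
  unbound = [n for n, variadic in spec if variadic and n not in bindings]
  if len(unbound) > 1:
    return False  # More than one unknown variadic: retry on a later pass.
  # Rank taken up by everything except the (at most one) unbound variadic.
  known = sum(
      len(bindings[n]) if variadic else 1
      for n, variadic in spec
      if not (variadic and n not in bindings)
  )
  free = len(shape) - known
  if free < 0 or (not unbound and free != 0):
    raise ValueError(f'rank mismatch: shape {shape} vs spec {spec}')
  pos = 0
  for n, variadic in spec:
    size = (len(bindings[n]) if n in bindings else free) if variadic else 1
    dims = shape[pos:pos + size]
    pos += size
    if n in bindings and bindings[n] != dims:
      raise ValueError(f'{n}: expected {bindings[n]}, got {dims}')
    bindings[n] = dims
  return True


@contextlib.contextmanager
def deferred_shape_checks() -> Iterator[Bindings]:
  """Collects shapes recorded in the body, then checks them all on exit."""
  bindings: Bindings = {}
  token = _pending.set([])
  try:
    yield bindings
    pending = _pending.get()
    while pending:  # Retry until a full pass makes no progress.
      still_pending = []
      for spec, shape in pending:
        if not _try_check(spec, shape, bindings):
          still_pending.append((spec, shape))
      if len(still_pending) == len(pending):
        names = sorted(
            {n for s, _ in pending for n, v in s if v and n not in bindings})
        raise ValueError(f'cannot determine variadic dims: {names}')
      pending = still_pending
  finally:
    _pending.reset(token)


# Usage with mask_invalid's annotations; x is deliberately visited first.
x_spec = [('B', True), ('C', True)]
mask_spec = [('B', True)]
with deferred_shape_checks() as dims:
  record_shape(x_spec, (4, 5, 2, 3))
  record_shape(mask_spec, (4, 5))
print(dims)  # {'B': (4, 5), 'C': (2, 3)}
```

Retrying to a fixed point is the simplest way to get order independence. Each successful check removes one pending entry, so a pass that removes nothing means the remaining variadic dimensions cannot be determined from the recorded shapes.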

patrick-kidger added the feature label on Mar 5, 2024