Refactor logp in BG/BB to remove Scan #703

Open · ColtAllen opened this issue on May 26, 2024 · 5 comments · May be fixed by #707
Labels: CLV · enhancement (New feature or request) · help wanted (Extra attention is needed) · priority: low

Comments

@ColtAllen (Collaborator)

logp in the BetaGeoBetaBinom distribution block contains a loop that is currently implemented with a pytensor Scan. It's possible to refactor this so that Scan is no longer needed:

import numpy as np
import pytensor.tensor as pt
from pymc.distributions.dist_math import betaln
from pytensor.graph.replace import vectorize_graph

# alpha, beta, gamma, delta, x, t_x, and T are the parameters and data
# already in scope inside logp.
i = pt.scalar("i", dtype="int64")
died = pt.lt(t_x + i, T)

# Unnormalized logp that the customer "died" at t_x + i.
# NOTE: this mask may be the source of the failing tests mentioned
# below; it arguably should compare i against T - t_x rather than t_x.
unnorm_logp_died_at_tx_plus_i = pt.where(
    pt.ge(t_x, i),
    (
        betaln(alpha + x, beta + t_x - x + i)
        + betaln(gamma + died, delta + t_x + i)
    ),
    -np.inf,
)

# Maximum prevents invalid T - t_x values from crashing logp
max_range = pt.maximum(pt.max(T - t_x), 0)
i_vec = pt.arange(max_range + 1)

# Substitute the vector of offsets for the scalar i, turning the
# scalar graph into a dense, Scan-free computation over all offsets.
unnorm_logp_died_at_tx_plus_i_vec = vectorize_graph(
    unnorm_logp_died_at_tx_plus_i,
    replace={i: i_vec},
)

# Reduce over the offset dimension in log space.
unnorm_logp = pt.logsumexp(unnorm_logp_died_at_tx_plus_i_vec, axis=0)

I compared both approaches in a dev notebook, and the Scan-free version is about 3x faster:

# w/ Scan
267 ms ± 6.69 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# w/o Scan
85.2 ms ± 339 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
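
For reference, a minimal sketch of how such a comparison could be run (the function names, symbolic inputs, and synthetic data below are assumptions for illustration, not the actual dev notebook):

import numpy as np
import pytensor

# Assume logp_scan and logp_dense are the two symbolic logp outputs,
# built over the same symbolic inputs x, t_x, and T.
logp_scan_fn = pytensor.function([x, t_x, T], logp_scan)
logp_dense_fn = pytensor.function([x, t_x, T], logp_dense)

# Hypothetical dataset of 11.2k rows with T = 12 periods.
rng = np.random.default_rng(42)
T_val = np.full(11_200, 12)
t_x_val = rng.integers(0, T_val + 1)
x_val = rng.integers(0, t_x_val + 1)

# In a notebook:
# %timeit logp_scan_fn(x_val, t_x_val, T_val)
# %timeit logp_dense_fn(x_val, t_x_val, T_val)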

However, the above code still requires modification, because tests are failing on the returned logp values.

ColtAllen added the enhancement (New feature or request), help wanted (Extra attention is needed), CLV, and priority: low labels on May 26, 2024
juanitorduz linked a pull request (#707) on May 30, 2024 that will close this issue
@ricardoV94 (Contributor) commented on May 31, 2024

Scan may be plenty fast in other backends: Numba and JAX. The first will be the default sometime in the future, and it's what nutpie uses; JAX is used for numpyro and blackjax. I would benchmark on those backends before bothering to get rid of it.

Also, for varied datasets (t_x very different across subjects) the non-Scan version will probably be slower, since it does a lot of useless computation. The dense/non-Scan approach evaluates the worst-case scenario (the biggest gap between T and t_x) for every row, even if that is only needed for 1 row out of 10,000.
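
To make that concrete, here is a rough illustration with made-up numbers (not from any real dataset):

import numpy as np

# Hypothetical dataset: 10,000 subjects, almost all with a small
# gap between T and t_x, plus a single outlier with a huge gap.
gaps = np.full(10_000, 5)
gaps[0] = 10_000

# Scan-style loop: each row evaluates only its own gap + 1 terms.
scan_terms = int(np.sum(gaps + 1))  # ~70,000 evaluations

# Dense version: every row evaluates max(gap) + 1 terms.
dense_terms = len(gaps) * (int(gaps.max()) + 1)  # ~100,010,000

print(dense_terms / scan_terms)  # dense does ~1400x the work here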

@juanitorduz (Collaborator)

Ok! Thanks for the input! I took the PR because I always want to play with Scan, but we can close it and run other benchmarks. We can always come back and change it, since we already have the code in a branch.

@ColtAllen (Collaborator, Author)

@ricardoV94 do you have a time estimate on when Numba will become the new default backend? I'm working on the BG/BB model right now, and currently NUTS is taking over an hour on my MacBook M2 Pro with a dataset of 11.2k rows.

@ricardoV94 (Contributor) commented on Jul 22, 2024

You can select other backends manually; there's no need to wait for the default to change.
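
For example, via PyMC's nuts_sampler argument (a minimal sketch; model stands in for the actual BG/BB model, and nutpie/numpyro/blackjax must be installed separately):

import pymc as pm

with model:  # the BG/BB model
    # Numba-compiled NUTS via nutpie:
    idata = pm.sample(nuts_sampler="nutpie")

    # or JAX-based NUTS:
    # idata = pm.sample(nuts_sampler="numpyro")
    # idata = pm.sample(nuts_sampler="blackjax")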

ColtAllen mentioned this issue on Aug 11, 2024
@juanitorduz (Collaborator)

Rescuing key commits from #707
