Fix issue with sampling Wishart Distribution #729

Open
njtierney opened this issue Oct 16, 2024 · 3 comments

@njtierney
Collaborator

Currently there's an issue with sampling a Wishart.

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
sigma <- matrix(
  data = c(1.2, 0.7, 0.7, 2.3),
  nrow = 2,
  ncol = 2
)

df <- 4

x <- wishart(df, sigma)[1, 2]
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 

m <- model(x)

draws <- mcmc(m)
#> running 4 chains simultaneously on up to 8 CPU cores
#> 
#>     warmup ====================================== 1000/1000 | eta:  0s | 2% bad
#>   sampling ====================================== 1000/1000 | eta:  0s | <1% bad

plot(draws)

Created on 2024-10-16 with reprex v2.1.1

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.4.1 Patched (2024-07-08 r86915)
#>  os       macOS Sonoma 14.5
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Australia/Hobart
#>  date     2024-10-16
#>  pandoc   3.2 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/aarch64/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date (UTC) lib source
#>  abind         1.4-5      2016-07-21 [1] CRAN (R 4.4.0)
#>  backports     1.5.0      2024-05-23 [1] CRAN (R 4.4.0)
#>  base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.4.0)
#>  callr         3.7.6      2024-03-25 [1] CRAN (R 4.4.0)
#>  cli           3.6.3      2024-06-21 [1] CRAN (R 4.4.0)
#>  coda          0.19-4.1   2024-01-31 [1] CRAN (R 4.4.0)
#>  codetools     0.2-20     2024-03-31 [2] CRAN (R 4.4.1)
#>  crayon        1.5.3      2024-06-20 [1] CRAN (R 4.4.0)
#>  curl          5.2.1      2024-03-01 [1] CRAN (R 4.4.0)
#>  digest        0.6.36     2024-06-23 [1] CRAN (R 4.4.0)
#>  evaluate      0.24.0     2024-06-10 [1] CRAN (R 4.4.0)
#>  fastmap       1.2.0      2024-05-15 [1] CRAN (R 4.4.0)
#>  fs            1.6.4.9000 2024-06-26 [1] Github (r-lib/fs@714990b)
#>  future        1.34.0     2024-07-29 [1] CRAN (R 4.4.0)
#>  globals       0.16.3     2024-03-08 [1] CRAN (R 4.4.0)
#>  glue          1.7.0      2024-01-09 [1] CRAN (R 4.4.0)
#>  greta       * 0.5.0      2024-10-16 [1] local
#>  highr         0.11       2024-05-26 [1] CRAN (R 4.4.0)
#>  hms           1.1.3      2023-03-21 [1] CRAN (R 4.4.0)
#>  htmltools     0.5.8.1    2024-04-04 [1] CRAN (R 4.4.0)
#>  jsonlite      1.8.8      2023-12-04 [1] CRAN (R 4.4.0)
#>  knitr         1.48       2024-07-07 [1] CRAN (R 4.4.0)
#>  lattice       0.22-6     2024-03-20 [2] CRAN (R 4.4.1)
#>  lifecycle     1.0.4      2023-11-07 [1] CRAN (R 4.4.0)
#>  listenv       0.9.1      2024-01-29 [1] CRAN (R 4.4.0)
#>  magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.4.0)
#>  Matrix        1.7-0      2024-04-26 [2] CRAN (R 4.4.1)
#>  parallelly    1.38.0     2024-07-27 [1] CRAN (R 4.4.0)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.4.0)
#>  png           0.1-8      2022-11-29 [1] CRAN (R 4.4.0)
#>  prettyunits   1.2.0      2023-09-24 [1] CRAN (R 4.4.0)
#>  processx      3.8.4      2024-03-16 [1] CRAN (R 4.4.0)
#>  progress      1.2.3      2023-12-06 [1] CRAN (R 4.4.0)
#>  ps            1.8.0      2024-09-12 [1] CRAN (R 4.4.1)
#>  R6            2.5.1      2021-08-19 [1] CRAN (R 4.4.0)
#>  Rcpp          1.0.13     2024-07-17 [1] CRAN (R 4.4.0)
#>  reprex        2.1.1      2024-07-06 [1] CRAN (R 4.4.0)
#>  reticulate    1.38.0     2024-06-19 [1] CRAN (R 4.4.0)
#>  rlang         1.1.4      2024-06-04 [1] CRAN (R 4.4.0)
#>  rmarkdown     2.27       2024-05-17 [1] CRAN (R 4.4.0)
#>  rstudioapi    0.16.0     2024-03-24 [1] CRAN (R 4.4.0)
#>  sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.4.0)
#>  tensorflow    2.16.0     2024-04-15 [1] CRAN (R 4.4.0)
#>  tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.4.0)
#>  tfruns        1.5.3      2024-04-19 [1] CRAN (R 4.4.0)
#>  vctrs         0.6.5      2023-12-01 [1] CRAN (R 4.4.0)
#>  whisker       0.4.1      2022-12-05 [1] CRAN (R 4.4.0)
#>  withr         3.0.1      2024-07-31 [1] CRAN (R 4.4.0)
#>  xfun          0.46       2024-07-18 [1] CRAN (R 4.4.0)
#>  xml2          1.3.6      2023-12-04 [1] CRAN (R 4.4.0)
#>  yaml          2.3.10     2024-07-26 [1] CRAN (R 4.4.0)
#> 
#>  [1] /Users/nick/Library/R/arm64/4.4/library
#>  [2] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.10.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:51:49) [Clang 16.0.6 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/numpy
#>  numpy_version:  1.26.4
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/tensorflow
#>  
#>  NOTE: Python version was forced by use_python() function
#> 
#> ──────────────────────────────────────────────────────────────────────────────

It seems we aren't correctly transforming the free state into a positive-definite matrix, which is what this call, which computes the Cholesky factor, requires:

x_chol <- tf$linalg$cholesky(x)

The transformation happens in the bijector: the functions responsible for mapping the unconstrained free state used during MCMC onto the constrained scale of the parameter, and for mapping values back again.
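To make the bijector idea concrete, here is a minimal base R sketch for a single positive parameter. It is not greta or TFP code: the `exp_bijector` list and its element names are made up for illustration, and the exponential density is just a stand-in target.

```r
# Toy bijector for a strictly positive parameter (illustrative only).
# The sampler works on `free` (any real number); the model sees exp(free) > 0.
exp_bijector <- list(
  forward = function(free) exp(free),
  inverse = function(x) log(x),
  # log |d forward / d free| = log(exp(free)) = free
  forward_log_det_jacobian = function(free) free
)

free <- c(-2, 0, 1.5)
x <- exp_bijector$forward(free)        # constrained values, all positive
round_trip <- exp_bijector$inverse(x)  # recovers the free state

# Log density on the free scale = log density of x + log-Jacobian adjustment
log_prob_free <- dexp(x, rate = 1, log = TRUE) +
  exp_bijector$forward_log_det_jacobian(free)
```

If the forward map does not land in the support the density expects, the log density on the free scale is no longer valid, which is the kind of failure suspected here.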

The one we use in Wishart is this:

greta/R/tf_functions.R

Lines 705 to 707 in 595f7b2

tf_covariance_cholesky_bijector <- function() {
  tfp$bijectors$FillTriangular(upper = TRUE)
}

Which is called in create_tf_bijector

covariance_matrix = tf_covariance_cholesky_bijector(),

Which is used in tf_log_jacobian_adjustment

greta/R/node_types.R

Lines 356 to 363 in cb14e95

  tf_from_free = function(x) {
    tf_bijector <- self$create_tf_bijector()
    tf_bijector$forward(x)
  },
  # adjustments for univariate variables
  tf_log_jacobian_adjustment = function(free) {
    tf_bijector <- self$create_tf_bijector()

Anyway, the result is that we don't evaluate the log prob properly, so sampling goes badly, as shown above.

It's quite possible we can swap out tfp.bijectors.FillTriangular for tfp.bijectors.FillScaleTriL, which also constrains the diagonal to be positive.
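A rough base R sketch of the difference between the two approaches, under some assumptions: `softplus()` and `fill_scale_tril()` are hypothetical helpers written for this comment, the fill order is simple row-wise and does not reproduce TFP's internal ordering, and the small `diag_shift` mirrors what I understand the TFP default to be. Note that FillScaleTriL produces a lower-triangular factor, whereas the current bijector uses `upper = TRUE`.

```r
# Softplus maps any real number to a strictly positive one
# (written this way to stay stable for large |x|).
softplus <- function(x) pmax(x, 0) + log1p(exp(-abs(x)))

# Hypothetical helper: fill a lower-triangular matrix from a free vector,
# then force the diagonal to be strictly positive.
fill_scale_tril <- function(free, diag_shift = 1e-5) {
  n <- (sqrt(8 * length(free) + 1) - 1) / 2
  stopifnot(n == round(n))
  l <- matrix(0, n, n)
  l[upper.tri(l, diag = TRUE)] <- free  # column-wise into the upper triangle
  l <- t(l)                             # becomes row-wise lower triangle
  diag(l) <- softplus(diag(l)) + diag_shift
  l
}

free <- c(-0.6, 0.2, -0.8)

# Plain triangular fill: the diagonal is whatever the free state happens to be
u_plain <- matrix(0, 2, 2)
u_plain[upper.tri(u_plain, diag = TRUE)] <- free

# Fill plus positive diagonal: a valid Cholesky factor
l <- fill_scale_tril(free)
sigma_free <- l %*% t(l)
eigenvalues <- eigen(sigma_free, symmetric = TRUE)$values
```

With the plain fill the diagonal can be negative, so the matrix is not a proper Cholesky factor. With the softplus step the diagonal is always positive, and `l %*% t(l)` is positive definite for any free state.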

@njtierney
Collaborator Author

Here is the output before the changes:

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
sigma <- matrix(
  data = c(1.2, 0.7, 0.7, 2.3),
  nrow = 2,
  ncol = 2
)

df <- 4

x <- wishart(df, sigma)[1, 2]
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 

m <- model(x)

draws <- mcmc(m)
#> running 4 chains simultaneously on up to 8 CPU cores
#> 
#>     warmup ====================================== 1000/1000 | eta:  0s | 3% bad
#>   sampling ====================================== 1000/1000 | eta:  0s | 1% bad

plot(draws)

coda::gelman.diag(draws)
#> Potential scale reduction factors:
#> 
#>   Point est. Upper C.I.
#> x       7.29       13.5

Created on 2024-10-16 with reprex v2.1.1
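For reading the `gelman.diag()` output: values near 1 suggest the chains agree, and values this far above 1 suggest they do not. Below is a sketch of the basic Gelman-Rubin calculation in base R. It is not coda's exact estimator (coda applies additional corrections), and `basic_rhat()` and the simulated chains are made up for illustration.

```r
# Basic potential scale reduction factor for one parameter.
# `chains` is an iterations x chains matrix (an assumed layout for this sketch).
basic_rhat <- function(chains) {
  n <- nrow(chains)
  w <- mean(apply(chains, 2, var))    # mean within-chain variance
  b <- n * var(colMeans(chains))      # between-chain variance
  var_hat <- (n - 1) / n * w + b / n  # pooled variance estimate
  sqrt(var_hat / w)
}

set.seed(42)
# Four chains sampling the same distribution
mixed <- matrix(rnorm(4000), nrow = 1000, ncol = 4)
# The same chains shifted so each sits in a different place
stuck <- sweep(mixed, 2, c(-3, -1, 1, 3), "+")

rhat_mixed <- basic_rhat(mixed)  # close to 1
rhat_stuck <- basic_rhat(stuck)  # well above 1
```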

And after the changes:

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
sigma <- matrix(
  data = c(1.2, 0.7, 0.7, 2.3),
  nrow = 2,
  ncol = 2
)

df <- 4

x <- wishart(df, sigma)[1, 2]
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 

m <- model(x)

draws <- mcmc(m)
#> running 4 chains simultaneously on up to 8 CPU cores
#> 
#>     warmup ====================================== 1000/1000 | eta:  0s
#>   sampling ====================================== 1000/1000 | eta:  0s

plot(draws)

coda::gelman.diag(draws)
#> Potential scale reduction factors:
#> 
#>   Point est. Upper C.I.
#> x       4.01       13.7

Created on 2024-10-16 with reprex v2.1.1

Comparing the end of the warmup and sampling progress lines, the "X% bad" samples part disappears with the changes.

@njtierney
Collaborator Author

Another demonstration of the issues in the log prob.

Before

  devtools::load_all(".")
#> ℹ Loading greta
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> 
#> ✔ Initialising python and checking dependencies ... done!
sigma <- matrix(
  data = c(1.2, 0.7, 0.7, 2.3),
  nrow = 2,
  ncol = 2
)
df <- 4
x <- wishart(df, sigma)[1, 2]
m <- model(x)
new_log_prob <- m$dag$generate_log_prob_function()
m$dag$define_tf_log_prob_function()
prob_input <- matrix(rnorm(12), 4, 3) # this gives us `x_chol` with nan...
# prob_input <- matrix(runif(12), 4, 3)
new_log_prob(prob_input)
#> $adjusted
#> tf.Tensor([nan nan nan nan], shape=(4), dtype=float64)
#> 
#> $unadjusted
#> tf.Tensor([nan nan nan nan], shape=(4), dtype=float64)

Created on 2024-10-16 with reprex v2.1.1
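One plausible route to these NaN values, sketched in base R. This is an assumption rather than a confirmed diagnosis of what greta does internally, and the matrices are made-up examples. The log determinant of a covariance matrix is usually computed from the diagonal of its Cholesky factor, and taking the log of a negative diagonal entry gives NaN.

```r
# Upper-triangular "factor" with a negative diagonal entry, as an
# unconstrained triangular fill could produce (hypothetical values).
u_bad <- matrix(c(-0.6, 0, 0.2, 0.8), 2, 2)

# log|Sigma| = 2 * sum(log(diag(chol(Sigma)))); log(-0.6) is NaN
log_det_bad <- suppressWarnings(2 * sum(log(diag(u_bad))))

# Same calculation with a positive diagonal works, and matches determinant()
u_good <- matrix(c(0.6, 0, 0.2, 0.8), 2, 2)
log_det_good <- 2 * sum(log(diag(u_good)))
log_det_check <- as.numeric(
  determinant(t(u_good) %*% u_good, logarithm = TRUE)$modulus
)
```

Once a NaN enters the log determinant it propagates through the rest of the log density, which would be consistent with both the adjusted and unadjusted values coming back as NaN.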

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.4.1 Patched (2024-07-08 r86915)
#>  os       macOS Sonoma 14.5
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Australia/Hobart
#>  date     2024-10-16
#>  pandoc   3.2 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/aarch64/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package     * version    date (UTC) lib source
#>    abind         1.4-5      2016-07-21 [1] CRAN (R 4.4.0)
#>    backports     1.5.0      2024-05-23 [1] CRAN (R 4.4.0)
#>    base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.4.0)
#>    brio          1.1.5      2024-04-24 [1] CRAN (R 4.4.0)
#>    cachem        1.1.0      2024-05-16 [1] CRAN (R 4.4.0)
#>    callr         3.7.6      2024-03-25 [1] CRAN (R 4.4.0)
#>    cli           3.6.3      2024-06-21 [1] CRAN (R 4.4.0)
#>    coda          0.19-4.1   2024-01-31 [1] CRAN (R 4.4.0)
#>    codetools     0.2-20     2024-03-31 [2] CRAN (R 4.4.1)
#>    crayon        1.5.3      2024-06-20 [1] CRAN (R 4.4.0)
#>    desc          1.4.3      2023-12-10 [1] CRAN (R 4.4.0)
#>    devtools      2.4.5      2022-10-11 [1] CRAN (R 4.4.0)
#>    digest        0.6.36     2024-06-23 [1] CRAN (R 4.4.0)
#>    ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.4.0)
#>    evaluate      0.24.0     2024-06-10 [1] CRAN (R 4.4.0)
#>    fastmap       1.2.0      2024-05-15 [1] CRAN (R 4.4.0)
#>    fs            1.6.4.9000 2024-06-26 [1] Github (r-lib/fs@714990b)
#>    future        1.34.0     2024-07-29 [1] CRAN (R 4.4.0)
#>    globals       0.16.3     2024-03-08 [1] CRAN (R 4.4.0)
#>    glue          1.7.0      2024-01-09 [1] CRAN (R 4.4.0)
#>  P greta       * 0.5.0      2024-10-16 [?] load_all()
#>    hms           1.1.3      2023-03-21 [1] CRAN (R 4.4.0)
#>    htmltools     0.5.8.1    2024-04-04 [1] CRAN (R 4.4.0)
#>    htmlwidgets   1.6.4      2023-12-06 [1] CRAN (R 4.4.0)
#>    httpuv        1.6.15     2024-03-26 [1] CRAN (R 4.4.0)
#>    jsonlite      1.8.8      2023-12-04 [1] CRAN (R 4.4.0)
#>    knitr         1.48       2024-07-07 [1] CRAN (R 4.4.0)
#>    later         1.3.2      2023-12-06 [1] CRAN (R 4.4.0)
#>    lattice       0.22-6     2024-03-20 [2] CRAN (R 4.4.1)
#>    lifecycle     1.0.4      2023-11-07 [1] CRAN (R 4.4.0)
#>    listenv       0.9.1      2024-01-29 [1] CRAN (R 4.4.0)
#>    magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.4.0)
#>    Matrix        1.7-0      2024-04-26 [2] CRAN (R 4.4.1)
#>    memoise       2.0.1.9000 2024-08-14 [1] Github (hadley/memoise@40db995)
#>    mime          0.12       2021-09-28 [1] CRAN (R 4.4.0)
#>    miniUI        0.1.1.1    2018-05-18 [1] CRAN (R 4.4.0)
#>    parallelly    1.38.0     2024-07-27 [1] CRAN (R 4.4.0)
#>    pkgbuild      1.4.4      2024-03-17 [1] CRAN (R 4.4.0)
#>    pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.4.0)
#>    pkgload       1.4.0      2024-06-28 [1] CRAN (R 4.4.0)
#>    png           0.1-8      2022-11-29 [1] CRAN (R 4.4.0)
#>    prettyunits   1.2.0      2023-09-24 [1] CRAN (R 4.4.0)
#>    processx      3.8.4      2024-03-16 [1] CRAN (R 4.4.0)
#>    profvis       0.3.8      2023-05-02 [1] CRAN (R 4.4.0)
#>    progress      1.2.3      2023-12-06 [1] CRAN (R 4.4.0)
#>    promises      1.3.0      2024-04-05 [1] CRAN (R 4.4.0)
#>    ps            1.8.0      2024-09-12 [1] CRAN (R 4.4.1)
#>    purrr         1.0.2      2023-08-10 [1] CRAN (R 4.4.0)
#>    R6            2.5.1      2021-08-19 [1] CRAN (R 4.4.0)
#>    Rcpp          1.0.13     2024-07-17 [1] CRAN (R 4.4.0)
#>    remotes       2.5.0      2024-03-17 [1] CRAN (R 4.4.0)
#>    reprex        2.1.1      2024-07-06 [1] CRAN (R 4.4.0)
#>    reticulate    1.38.0     2024-06-19 [1] CRAN (R 4.4.0)
#>    rlang         1.1.4      2024-06-04 [1] CRAN (R 4.4.0)
#>    rmarkdown     2.27       2024-05-17 [1] CRAN (R 4.4.0)
#>    rprojroot     2.0.4      2023-11-05 [1] CRAN (R 4.4.0)
#>    rstudioapi    0.16.0     2024-03-24 [1] CRAN (R 4.4.0)
#>    sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.4.0)
#>    shiny         1.9.1      2024-08-01 [1] CRAN (R 4.4.0)
#>    stringi       1.8.4      2024-05-06 [1] CRAN (R 4.4.0)
#>    stringr       1.5.1      2023-11-14 [1] CRAN (R 4.4.0)
#>    tensorflow    2.16.0     2024-04-15 [1] CRAN (R 4.4.0)
#>    testthat    * 3.2.1.1    2024-04-14 [1] CRAN (R 4.4.0)
#>    tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.4.0)
#>    tfruns        1.5.3      2024-04-19 [1] CRAN (R 4.4.0)
#>    urlchecker    1.0.1.9000 2024-08-27 [1] Github (r-lib/urlchecker@ac38ea4)
#>    usethis       3.0.0      2024-07-29 [1] CRAN (R 4.4.0)
#>    vctrs         0.6.5      2023-12-01 [1] CRAN (R 4.4.0)
#>    whisker       0.4.1      2022-12-05 [1] CRAN (R 4.4.0)
#>    withr         3.0.1      2024-07-31 [1] CRAN (R 4.4.0)
#>    xfun          0.46       2024-07-18 [1] CRAN (R 4.4.0)
#>    xtable        1.8-4      2019-04-21 [1] CRAN (R 4.4.0)
#>    yaml          2.3.10     2024-07-26 [1] CRAN (R 4.4.0)
#>    yesno         0.1.3      2024-07-26 [1] CRAN (R 4.4.0)
#> 
#>  [1] /Users/nick/Library/R/arm64/4.4/library
#>  [2] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library
#> 
#>  P ── Loaded and on-disk path mismatch.
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.10.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:51:49) [Clang 16.0.6 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/numpy
#>  numpy_version:  1.26.4
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/tensorflow
#>  
#>  NOTE: Python version was forced by use_python() function
#> 
#> ──────────────────────────────────────────────────────────────────────────────

After

devtools::load_all(".")
#> ℹ Loading greta
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> 
#> ✔ Initialising python and checking dependencies ... done!
sigma <- matrix(
  data = c(1.2, 0.7, 0.7, 2.3),
  nrow = 2,
  ncol = 2
)
df <- 4
x <- wishart(df, sigma)[1, 2]
m <- model(x)
new_log_prob <- m$dag$generate_log_prob_function()
m$dag$define_tf_log_prob_function()
prob_input <- matrix(rnorm(12), 4, 3) # this gives us `x_chol` with nan...
# prob_input <- matrix(runif(12), 4, 3)
new_log_prob(prob_input)
#> $adjusted
#> tf.Tensor([-9.03816533 -7.69794824 -7.43463759 -9.22124165], shape=(4), dtype=float64)
#> 
#> $unadjusted
#> tf.Tensor([-6.29694024 -5.88441816 -5.81974113 -6.38780931], shape=(4), dtype=float64)

Created on 2024-10-16 with reprex v2.1.1


@njtierney
Collaborator Author

While this is improved in #730, it is not resolved. The bias in the draws below indicates that something deeper is wrong:

devtools::load_all(".")
#> ℹ Loading greta
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> 
#> ✔ Initialising python and checking dependencies ... done!
sigma <- matrix(
  data = c(1.2, 0.7, 0.7, 2.3),
  nrow = 2,
  ncol = 2
)

df <- 4

x <- wishart(df, sigma)[1, 2]

iid_wish <- function(n) {
  rWishart(n, df, sigma)[1, 2, ]
}
m <- model(x)
n_draws <- 5000
draws <- mcmc(m, warmup = 2000, n_samples = n_draws)
#> running 4 chains simultaneously on up to 8 CPU cores
#>     warmup ====================================== 2000/2000 | eta:  0s
#>   sampling ====================================== 5000/5000 | eta:  0s
iid_draws <- iid_wish(n = n_draws)
plot(draws)

coda::gelman.diag(draws)
#> Potential scale reduction factors:
#> 
#>   Point est. Upper C.I.
#> x       6.63       12.8

stats::qqplot(
  x = as.matrix(draws),
  y = iid_draws,
  main = "Wishart draws with HMC"
)

graphics::abline(0, 1)

Created on 2024-10-18 with reprex v2.1.1

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.4.1 Patched (2024-07-08 r86915)
#>  os       macOS Sonoma 14.5
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Australia/Hobart
#>  date     2024-10-18
#>  pandoc   3.2 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/aarch64/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package     * version    date (UTC) lib source
#>    abind         1.4-5      2016-07-21 [1] CRAN (R 4.4.0)
#>    backports     1.5.0      2024-05-23 [1] CRAN (R 4.4.0)
#>    base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.4.0)
#>    brio          1.1.5      2024-04-24 [1] CRAN (R 4.4.0)
#>    cachem        1.1.0      2024-05-16 [1] CRAN (R 4.4.0)
#>    callr         3.7.6      2024-03-25 [1] CRAN (R 4.4.0)
#>    cli           3.6.3      2024-06-21 [1] CRAN (R 4.4.0)
#>    coda          0.19-4.1   2024-01-31 [1] CRAN (R 4.4.0)
#>    codetools     0.2-20     2024-03-31 [2] CRAN (R 4.4.1)
#>    crayon        1.5.3      2024-06-20 [1] CRAN (R 4.4.0)
#>    curl          5.2.1      2024-03-01 [1] CRAN (R 4.4.0)
#>    desc          1.4.3      2023-12-10 [1] CRAN (R 4.4.0)
#>    devtools      2.4.5      2022-10-11 [1] CRAN (R 4.4.0)
#>    digest        0.6.36     2024-06-23 [1] CRAN (R 4.4.0)
#>    ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.4.0)
#>    evaluate      0.24.0     2024-06-10 [1] CRAN (R 4.4.0)
#>    fastmap       1.2.0      2024-05-15 [1] CRAN (R 4.4.0)
#>    fs            1.6.4.9000 2024-06-26 [1] Github (r-lib/fs@714990b)
#>    future        1.34.0     2024-07-29 [1] CRAN (R 4.4.0)
#>    globals       0.16.3     2024-03-08 [1] CRAN (R 4.4.0)
#>    glue          1.7.0      2024-01-09 [1] CRAN (R 4.4.0)
#>  P greta       * 0.5.0      2024-10-16 [?] load_all()
#>    highr         0.11       2024-05-26 [1] CRAN (R 4.4.0)
#>    hms           1.1.3      2023-03-21 [1] CRAN (R 4.4.0)
#>    htmltools     0.5.8.1    2024-04-04 [1] CRAN (R 4.4.0)
#>    htmlwidgets   1.6.4      2023-12-06 [1] CRAN (R 4.4.0)
#>    httpuv        1.6.15     2024-03-26 [1] CRAN (R 4.4.0)
#>    jsonlite      1.8.8      2023-12-04 [1] CRAN (R 4.4.0)
#>    knitr         1.48       2024-07-07 [1] CRAN (R 4.4.0)
#>    later         1.3.2      2023-12-06 [1] CRAN (R 4.4.0)
#>    lattice       0.22-6     2024-03-20 [2] CRAN (R 4.4.1)
#>    lifecycle     1.0.4      2023-11-07 [1] CRAN (R 4.4.0)
#>    listenv       0.9.1      2024-01-29 [1] CRAN (R 4.4.0)
#>    magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.4.0)
#>    Matrix        1.7-0      2024-04-26 [2] CRAN (R 4.4.1)
#>    memoise       2.0.1.9000 2024-08-14 [1] Github (hadley/memoise@40db995)
#>    mime          0.12       2021-09-28 [1] CRAN (R 4.4.0)
#>    miniUI        0.1.1.1    2018-05-18 [1] CRAN (R 4.4.0)
#>    parallelly    1.38.0     2024-07-27 [1] CRAN (R 4.4.0)
#>    pkgbuild      1.4.4      2024-03-17 [1] CRAN (R 4.4.0)
#>    pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.4.0)
#>    pkgload       1.4.0      2024-06-28 [1] CRAN (R 4.4.0)
#>    png           0.1-8      2022-11-29 [1] CRAN (R 4.4.0)
#>    prettyunits   1.2.0      2023-09-24 [1] CRAN (R 4.4.0)
#>    processx      3.8.4      2024-03-16 [1] CRAN (R 4.4.0)
#>    profvis       0.3.8      2023-05-02 [1] CRAN (R 4.4.0)
#>    progress      1.2.3      2023-12-06 [1] CRAN (R 4.4.0)
#>    promises      1.3.0      2024-04-05 [1] CRAN (R 4.4.0)
#>    ps            1.8.0      2024-09-12 [1] CRAN (R 4.4.1)
#>    purrr         1.0.2      2023-08-10 [1] CRAN (R 4.4.0)
#>    R6            2.5.1      2021-08-19 [1] CRAN (R 4.4.0)
#>    Rcpp          1.0.13     2024-07-17 [1] CRAN (R 4.4.0)
#>    remotes       2.5.0      2024-03-17 [1] CRAN (R 4.4.0)
#>    reprex        2.1.1      2024-07-06 [1] CRAN (R 4.4.0)
#>    reticulate    1.38.0     2024-06-19 [1] CRAN (R 4.4.0)
#>    rlang         1.1.4      2024-06-04 [1] CRAN (R 4.4.0)
#>    rmarkdown     2.27       2024-05-17 [1] CRAN (R 4.4.0)
#>    rprojroot     2.0.4      2023-11-05 [1] CRAN (R 4.4.0)
#>    rstudioapi    0.16.0     2024-03-24 [1] CRAN (R 4.4.0)
#>    sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.4.0)
#>    shiny         1.9.1      2024-08-01 [1] CRAN (R 4.4.0)
#>    stringi       1.8.4      2024-05-06 [1] CRAN (R 4.4.0)
#>    stringr       1.5.1      2023-11-14 [1] CRAN (R 4.4.0)
#>    tensorflow    2.16.0     2024-04-15 [1] CRAN (R 4.4.0)
#>    testthat    * 3.2.1.1    2024-04-14 [1] CRAN (R 4.4.0)
#>    tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.4.0)
#>    tfruns        1.5.3      2024-04-19 [1] CRAN (R 4.4.0)
#>    urlchecker    1.0.1.9000 2024-08-27 [1] Github (r-lib/urlchecker@ac38ea4)
#>    usethis       3.0.0      2024-07-29 [1] CRAN (R 4.4.0)
#>    vctrs         0.6.5      2023-12-01 [1] CRAN (R 4.4.0)
#>    whisker       0.4.1      2022-12-05 [1] CRAN (R 4.4.0)
#>    withr         3.0.1      2024-07-31 [1] CRAN (R 4.4.0)
#>    xfun          0.46       2024-07-18 [1] CRAN (R 4.4.0)
#>    xml2          1.3.6      2023-12-04 [1] CRAN (R 4.4.0)
#>    xtable        1.8-4      2019-04-21 [1] CRAN (R 4.4.0)
#>    yaml          2.3.10     2024-07-26 [1] CRAN (R 4.4.0)
#>    yesno         0.1.3      2024-07-26 [1] CRAN (R 4.4.0)
#> 
#>  [1] /Users/nick/Library/R/arm64/4.4/library
#>  [2] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library
#> 
#>  P ── Loaded and on-disk path mismatch.
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.10.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:51:49) [Clang 16.0.6 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/numpy
#>  numpy_version:  1.26.4
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/tensorflow
#>  
#>  NOTE: Python version was forced by use_python() function
#> 
#> ──────────────────────────────────────────────────────────────────────────────
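
As a rough numerical complement to the QQ plot above, the iid draws can be checked against the known Wishart moments, E[W] = df * sigma and Var(W[i, j]) = df * (sigma[i, j]^2 + sigma[i, i] * sigma[j, j]). This is only a sketch in base R: `mean_z_score()` is a hypothetical helper, not part of greta, and the seed is arbitrary.

```r
# Reference moments for the [1, 2] element of a Wishart(df, sigma) variable
sigma <- matrix(
  data = c(1.2, 0.7, 0.7, 2.3),
  nrow = 2,
  ncol = 2
)
df <- 4

# E[W[1, 2]] = df * sigma[1, 2]
expected_mean <- df * sigma[1, 2] # 2.8
# Var(W[1, 2]) = df * (sigma[1, 2]^2 + sigma[1, 1] * sigma[2, 2])
expected_var <- df * (sigma[1, 2]^2 + sigma[1, 1] * sigma[2, 2]) # 13

# Hypothetical helper: z-score of a sample mean against the theoretical mean.
# The standard error used here assumes independent draws.
mean_z_score <- function(samples) {
  (mean(samples) - expected_mean) / sqrt(expected_var / length(samples))
}

set.seed(2024) # arbitrary seed, for reproducibility only
iid_draws <- rWishart(5000, df, sigma)[1, 2, ]
mean_z_score(iid_draws)
```

For the iid draws the z-score should sit within a few units of zero. Applying the same helper to `as.matrix(draws)` would overstate significance, because MCMC draws are autocorrelated, but a sample mean far from 2.8 would still point to the same bias the QQ plot shows.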
