Skip to content

Commit

Permalink
fix: copy rng state to callr session (#104)
Browse files Browse the repository at this point in the history
* fix: copy rng state to callr session

* fix: handle no seed
  • Loading branch information
be-marc authored Apr 9, 2024
1 parent c48fb32 commit 45471b0
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
18 changes: 14 additions & 4 deletions R/encapsulate.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,12 @@ encapsulate = function(method, .f, .args = list(), .opts = list(), .pkgs = chara
} else { # method == "callr"
require_namespaces("callr")

# callr does not copy the RNG state, so we need to do it manually
.rng_state = .GlobalEnv$.Random.seed
logfile = tempfile()
now = proc.time()[3L]
result = try(callr::r(callr_wrapper,
list(.f = .f, .args = .args, .opts = .opts, .pkgs = .pkgs, .seed = .seed),
list(.f = .f, .args = .args, .opts = .opts, .pkgs = .pkgs, .seed = .seed, .rng_state = .rng_state),
stdout = logfile, stderr = logfile, timeout = .timeout), silent = TRUE)
elapsed = proc.time()[3L] - now

Expand All @@ -116,8 +118,10 @@ encapsulate = function(method, .f, .args = list(), .opts = list(), .pkgs = chara
log = c(log, sprintf("[ERR] callr process exited with status %i", status))
}
result = NULL
} else {
if (!is.null(result$rng_state)) assign(".Random.seed", result$rng_state, envir = globalenv())
result = result$result
}

log = parse_callr(log)
}

Expand Down Expand Up @@ -163,7 +167,7 @@ parse_callr = function(log) {
log[]
}

callr_wrapper = function(.f, .args, .opts, .pkgs, .seed) {
callr_wrapper = function(.f, .args, .opts, .pkgs, .seed, .rng_state) {
suppressPackageStartupMessages({
lapply(.pkgs, requireNamespace)
})
Expand All @@ -173,7 +177,10 @@ callr_wrapper = function(.f, .args, .opts, .pkgs, .seed) {
set.seed(.seed)
}

withCallingHandlers(
# restore RNG state from parent R session
if (!is.null(.rng_state)) assign(".Random.seed", .rng_state, envir = globalenv())

result = withCallingHandlers(
tryCatch(do.call(.f, .args),
error = function(e) {
cat("[ERR]", gsub("\r?\n|\r", "<br>", conditionMessage(e)), "\n")
Expand All @@ -185,4 +192,7 @@ callr_wrapper = function(.f, .args, .opts, .pkgs, .seed) {
invokeRestart("muffleWarning")
}
)

# copy new RNG state back to parent R session
list(result = result, rng_state = .GlobalEnv$.Random.seed)
}
41 changes: 41 additions & 0 deletions tests/testthat/test_encapsulate.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,44 @@ test_that("try", {
expect_message(encapsulate("try", function(...) message("foo")))
expect_warning(encapsulate("try", function(...) warning("foo")))
})

test_that("callr rng state", {

rng_state = .GlobalEnv$.Random.seed
on.exit({.GlobalEnv$.Random.seed = rng_state})

fun = function() {
sample(seq(1000), 1)
}

# no seed
res = encapsulate("callr", fun)
expect_number(res$result)

set.seed(1, kind = "Mersenne-Twister")
res = encapsulate("callr", fun)
expect_equal(res$result, 836)
expect_equal(sample(seq(1000), 1), 679)

set.seed(1, kind = "Mersenne-Twister")
expect_equal(fun(), 836)
expect_equal(sample(seq(1000), 1), 679)

set.seed(1, kind = "Wichmann-Hill")
res = encapsulate("callr", fun)
expect_equal(res$result, 309)
expect_equal(sample(seq(1000), 1), 885)

set.seed(1, kind = "Wichmann-Hill")
expect_equal(fun(), 309)
expect_equal(sample(seq(1000), 1), 885)

set.seed(1, kind = "L'Ecuyer-CMRG")
res = encapsulate("callr", fun)
expect_equal(res$result, 371)
expect_equal(sample(seq(1000), 1), 359)

set.seed(1, kind = "L'Ecuyer-CMRG")
expect_equal(fun(), 371)
expect_equal(sample(seq(1000), 1), 359)
})

0 comments on commit 45471b0

Please sign in to comment.