feat: retry failed tasks (#21)
* feat: retry failed tasks

* refactor: optimize redis call
be-marc authored Dec 3, 2023
1 parent 96f8ef9 commit 23fe3db
Showing 9 changed files with 293 additions and 133 deletions.
102 changes: 65 additions & 37 deletions R/Rush.R
@@ -180,13 +180,18 @@ Rush = R6::R6Class("Rush",
heartbeat_expire = NULL,
lgr_thresholds = NULL,
lgr_buffer_size = 0,
max_retries = 0,
supervise = TRUE,
worker_loop = worker_loop_default,
...
) {
n_workers = assert_count(n_workers %??% rush_env$n_workers)
assert_flag(wait_for_workers)
assert_flag(supervise)
r = self$connector

# set global maximum retries of tasks
private$.max_retries = assert_count(max_retries)

# push worker config to redis
private$.push_worker_config(
@@ -196,6 +201,7 @@ Rush = R6::R6Class("Rush",
heartbeat_expire = heartbeat_expire,
lgr_thresholds = lgr_thresholds,
lgr_buffer_size = lgr_buffer_size,
max_retries = max_retries,
worker_loop = worker_loop,
...
)
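
A minimal usage sketch for the new `max_retries` argument: it caps how often a task lost to a crashed worker is requeued. The method name `$start_workers()` and the `fun` and `n_workers` arguments are illustrative assumptions, not part of this diff.

# Sketch: allow each lost task to be requeued up to 3 times.
# `rush` is a configured Rush controller.
rush$start_workers(fun = fun, n_workers = 2, max_retries = 3)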
@@ -306,7 +312,7 @@ Rush = R6::R6Class("Rush",

# Push terminate signal to worker
cmds = map(worker_ids, function(worker_id) {
c("SET", private$.get_worker_key("terminate", worker_id), "TRUE")
c("SET", private$.get_worker_key("terminate", worker_id), "1")
})
r$pipeline(.commands = cmds)

@@ -353,10 +359,11 @@ Rush = R6::R6Class("Rush",
#' Workers with a heartbeat process are checked with the heartbeat.
#' Lost tasks are marked as `"lost"`.
#'
#' @param restart (`logical(1)`)\cr
#' @param restart_workers (`logical(1)`)\cr
#' Whether to restart lost workers.
detect_lost_workers = function(restart = FALSE) {
assert_flag(restart)
detect_lost_workers = function(restart_workers = FALSE, restart_tasks = FALSE) {
assert_flag(restart_workers)
assert_flag(restart_tasks)
r = self$connector

# check workers with a heartbeat
@@ -392,7 +399,7 @@ Rush = R6::R6Class("Rush",
lost_workers = local_workers[!running]
lg$error("Lost %i worker(s): %s", length(lost_workers), str_collapse(lost_workers))

if (restart) {
if (restart_workers) {
self$restart_workers(unlist(lost_workers))
lost_workers
} else {
@@ -412,17 +419,34 @@ Rush = R6::R6Class("Rush",
if (length(lost_workers)) {
running_tasks = self$fetch_running_tasks(fields = "worker_extra")
if (!nrow(running_tasks)) return(invisible(self))
keys = running_tasks[lost_workers, keys, on = "worker_id"]
lost_workers = unlist(lost_workers)
keys = running_tasks[list(lost_workers), keys, on = "worker_id"]

lg$error("Lost %i task(s): %s", length(keys), str_collapse(keys))

cmds = unlist(map(keys, function(key) {
list(
list("HSET", key, "state", failed_state),
c("SREM", private$.get_key("running_tasks"), key),
c("RPUSH", private$.get_key("failed_tasks"), key))
}), recursive = FALSE)
r$pipeline(.commands = cmds)
if (restart_tasks) {

# check whether the tasks should be retried
retry = self$n_tries(keys) < private$.max_retries
keys = keys[retry]

if (length(keys)) {
lg$error("Retry %i lost task(s): %s", length(keys), str_collapse(keys))
cmds = map(keys, function(key) {
c("HINCRBY", key, "n_tries", 1)
})
cmds = c(cmds, list(
c("RPUSH", private$.get_key("queued_tasks"), keys),
c("SREM", private$.get_key("running_tasks"), keys)
))
r$pipeline(.commands = cmds)
}
} else {
cmds = list(
c("RPUSH", private$.get_key("failed_tasks"), keys),
c("SREM", private$.get_key("running_tasks"), keys))
r$pipeline(.commands = cmds)
}
}

return(invisible(self))
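
The retry branch above only runs when a caller opts in via `restart_tasks`. A sketch of a supervising loop built on it; `rush` is assumed to be a started controller with a positive `max_retries`, and the stop condition and polling interval are illustrative:

# Requeue tasks of crashed workers until their retry budget
# (max_retries) is exhausted; also restart the workers themselves.
repeat {
  rush$detect_lost_workers(restart_workers = TRUE, restart_tasks = TRUE)
  if (!rush$n_queued_tasks) break  # illustrative stop condition
  Sys.sleep(1)                     # illustrative polling interval
}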
@@ -523,10 +547,10 @@ Rush = R6::R6Class("Rush",

lg$debug("Pushing %i task(s) to the shared queue", length(xss))

keys = self$write_hashes(xs = xss, xs_extra = extra, state = "queued")
keys = self$write_hashes(xs = xss, xs_extra = extra)
r$command(c("LPUSH", private$.get_key("queued_tasks"), keys))
r$command(c("SADD", private$.get_key("all_tasks"), keys))
if (terminate_workers) r$command(c("SET", private$.get_key("terminate_on_idle"), "TRUE"))
if (terminate_workers) r$command(c("SET", private$.get_key("terminate_on_idle"), 1))

return(invisible(keys))
},
@@ -559,7 +583,7 @@ Rush = R6::R6Class("Rush",
lg$debug("Pushing %i task(s) to %i priority queue(s) and %i task(s) to the shared queue.",
sum(!is.na(priority)), length(unique(priority[!is.na(priority)])), sum(is.na(priority)))

keys = self$write_hashes(xs = xss, xs_extra = extra, state = "queued")
keys = self$write_hashes(xs = xss, xs_extra = extra)
cmds = pmap(list(priority, keys), function(worker_id, key) {
if (is.na(worker_id)) {
c("LPUSH", private$.get_key("queued_tasks"), key)
@@ -824,25 +848,14 @@ Rush = R6::R6Class("Rush",
#' @param keys (`character()`)\cr
#' Keys of the hashes.
#' If `NULL` new keys are generated.
#' @param state (`character(1)`)\cr
#' State of the hashes.
#'
#' @return (`character()`)\cr
#' Keys of the hashes.
write_hashes = function(..., .values = list(), keys = NULL, state = NA_character_) {
write_hashes = function(..., .values = list(), keys = NULL) {
values = discard(c(list(...), .values), function(l) !length(l))
assert_list(values, names = "unique", types = "list", min.len = 1)
fields = names(values)
keys = assert_character(keys %??% uuid::UUIDgenerate(n = length(values[[1]])), len = length(values[[1]]), .var.name = "keys")
assert_string(state, na.ok = TRUE)
bin_state = switch(state,
"queued" = queued_state,
"running" = running_state,
"failed" = failed_state,
"finished" = finished_state,
`NA_character_` = na_state,
redux::object_to_bin(list(state = state))
)

lg$debug("Writting %i hash(es) with %i field(s)", length(keys), length(fields))

@@ -856,7 +869,7 @@ Rush = R6::R6Class("Rush",
# merge fields and values alternatively
# c and rbind are fastest option in R
# data is not copied
c("HSET", key, c(rbind(fields, bin_values)), "state", list(bin_state))
c("HSET", key, c(rbind(fields, bin_values)))
})

self$connector$pipeline(.commands = cmds)
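
The `c(rbind(...))` idiom above interleaves field names with values without an explicit loop, producing the `f1 v1 f2 v2 ...` order that Redis `HSET` expects. Shown standalone with plain strings instead of serialized values:

fields = c("xs", "ys_extra")
values = c("a", "b")
c(rbind(fields, values))
# "xs" "a" "ys_extra" "b"  -- rbind() pairs column-wise, c() flattens column-major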
@@ -893,6 +906,23 @@ Rush = R6::R6Class("Rush",
# unserialize lists of the second level
# combine elements of the third level to one list
map(hashes, function(hash) unlist(map_if(hash, function(x) !is.null(x), redux::bin_to_object), recursive = FALSE))
},

#' @description
#' Returns the number of attempts to evaluate a task.
#'
#' @param keys (`character()`)\cr
#' Keys of the tasks.
#'
#' @return (`integer()`)\cr
#' Number of attempts.
n_tries = function(keys) {
assert_character(keys)
r = self$connector

# n_tries is not set when the task has never failed before
n_tries = r$pipeline(.commands = map(keys, function(key) c("HGET", key, "n_tries")))
map_int(n_tries, function(value) if (is.null(value)) 0L else as.integer(value))
}
),

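The `NULL` check works because the `n_tries` field is only ever created by the retry path: Redis `HINCRBY` treats a missing field as 0 and returns 1 on the first increment, while `HGET` on a task that never failed returns `NULL`. A sketch against a local Redis via redux; the key name is illustrative (real keys are UUIDs prefixed by the network id):

r = redux::hiredis()
r$HGET("task-key", "n_tries")        # NULL: the task never failed
r$HINCRBY("task-key", "n_tries", 1)  # creates the field and returns 1
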
@@ -1119,6 +1149,8 @@ Rush = R6::R6Class("Rush",
#
.hostname = NULL,

.max_retries = NULL,

# prefix key with instance id
.get_key = function(key) {
sprintf("%s:%s", self$network_id, key)
@@ -1138,6 +1170,7 @@ Rush = R6::R6Class("Rush",
heartbeat_expire = NULL,
lgr_thresholds = NULL,
lgr_buffer_size = 0,
max_retries = 0,
worker_loop = worker_loop_default,
...
) {
@@ -1148,6 +1181,7 @@ Rush = R6::R6Class("Rush",
if (!is.null(heartbeat_period)) require_namespaces("callr")
assert_vector(lgr_thresholds, names = "named", null.ok = TRUE)
assert_count(lgr_buffer_size)
assert_count(max_retries)
assert_function(worker_loop)
dots = list(...)
r = self$connector
@@ -1166,7 +1200,8 @@ Rush = R6::R6Class("Rush",
heartbeat_period = heartbeat_period,
heartbeat_expire = heartbeat_expire,
lgr_thresholds = lgr_thresholds,
lgr_buffer_size = lgr_buffer_size)
lgr_buffer_size = lgr_buffer_size,
max_retries = max_retries)

# arguments needed for initializing the worker
start_args = list(
@@ -1241,10 +1276,3 @@ Rush = R6::R6Class("Rush",
)
)

# common state for all tasks
# used in $write_hashes()
queued_state = redux::object_to_bin(list(state = "queued"))
running_state = redux::object_to_bin(list(state = "running"))
failed_state = redux::object_to_bin(list(state = "failed"))
finished_state = redux::object_to_bin(list(state = "finished"))
na_state = redux::object_to_bin(list(state = NA_character_))
41 changes: 28 additions & 13 deletions R/RushWorker.R
@@ -43,12 +43,14 @@ RushWorker = R6::R6Class("RushWorker",
heartbeat_period = NULL,
heartbeat_expire = NULL,
lgr_thresholds = NULL,
lgr_buffer_size = 0
lgr_buffer_size = 0,
max_retries = 0
) {
super$initialize(network_id = network_id, config = config)

self$host = assert_choice(host, c("local", "remote"))
self$worker_id = assert_string(worker_id %??% uuid::UUIDgenerate())
private$.max_retries = assert_count(max_retries)
r = self$connector

# setup heartbeat
@@ -137,7 +139,7 @@ RushWorker = R6::R6Class("RushWorker",

lg$debug("Pushing %i running task(s).", length(xss))

keys = self$write_hashes(xs = xss, xs_extra = extra, state = "running")
keys = self$write_hashes(xs = xss, xs_extra = extra)
r$command(c("SADD", private$.get_key("running_tasks"), keys))
r$command(c("SADD", private$.get_key("all_tasks"), keys))

@@ -156,7 +158,7 @@ RushWorker = R6::R6Class("RushWorker",
key = r$command(c("BLMPOP", timeout, 2, private$.get_worker_key("queued_tasks"), private$.get_key("queued_tasks"), "RIGHT"))[[2]][[1]]

if (is.null(key)) return(NULL)
self$write_hashes(worker_extra = list(list(pid = Sys.getpid(), worker_id = self$worker_id)), keys = key, state = "running")
self$write_hashes(worker_extra = list(list(pid = Sys.getpid(), worker_id = self$worker_id)), keys = key)

# move key from queued to running
r$command(c("SADD", private$.get_key("running_tasks"), key))
@@ -178,28 +180,41 @@ RushWorker = R6::R6Class("RushWorker",
#' Status of the tasks.
#' If `"finished"` the tasks are moved to the finished tasks.
#' If `"error"` the tasks are moved to the failed tasks.
push_results = function(keys, yss = list(), extra = list(), conditions = list(), state = "finished") {
push_results = function(keys, yss = list(), extra = list(), conditions = list()) {
assert_string(keys)
assert_list(yss, types = "list")
assert_list(extra, types = "list")
assert_list(conditions, types = "list")
assert_choice(state, c("finished", "failed"))
r = self$connector

# write result to hash
self$write_hashes(ys = yss, ys_extra = extra, condition = conditions, keys = keys, state = state)

destination = if (state == "finished") "finished_tasks" else "failed_tasks"
self$write_hashes(ys = yss, ys_extra = extra, condition = conditions, keys = keys)

# move key from running to finished or failed
# move key from running to finished
# keys of finished and failed tasks are stored in a list, i.e. they are ordered by time.
# each rush instance only needs to record how many results it has already seen
# to cheaply get the latest results and cache the finished tasks
# under some conditions a set would be more advantageous, e.g. to check whether a task is finished,
# but at the moment a list seems to be the better option
r$pipeline(.commands = list(
c("SREM", private$.get_key("running_tasks"), keys),
c("RPUSH", private$.get_key(destination), keys)
c("RPUSH", private$.get_key("finished_tasks"), keys)
))

return(invisible(self))
},

push_failed = function(keys, conditions) {
assert_string(keys)
assert_list(conditions, types = "list")
r = self$connector

# write condition to hash
self$write_hashes(condition = conditions, keys = keys)

# move key from running to failed
r$pipeline(.commands = list(
c("SREM", private$.get_key("running_tasks"), keys),
c("RPUSH", private$.get_key("failed_tasks"), keys)
))

return(invisible(self))
Expand All @@ -223,15 +238,15 @@ RushWorker = R6::R6Class("RushWorker",
#' Used in the worker loop to determine whether to continue.
terminated = function() {
r = self$connector
r$GET(private$.get_worker_key("terminate")) %??% "FALSE" == "TRUE"
as.logical(r$EXISTS(private$.get_worker_key("terminate")))
},

#' @field terminated_on_idle (`logical(1)`)\cr
#' Whether to shutdown the worker if no tasks are queued.
#' Used in the worker loop to determine whether to continue.
terminated_on_idle = function() {
r = self$connector
r$GET(private$.get_key("terminate_on_idle")) %??% "FALSE" == "TRUE" && !as.logical(self$n_queued_tasks)
as.logical(r$EXISTS(private$.get_key("terminate_on_idle"))) && !as.logical(self$n_queued_tasks)
}
)
)
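
Both active bindings now use the flag pattern this commit switches to: flags are written as `SET key 1` and tested for presence with `EXISTS`, instead of storing the string `"TRUE"` and comparing it on read. Standalone sketch (key name illustrative):

r = redux::hiredis()
r$SET("net:worker-1:terminate", "1")
as.logical(r$EXISTS("net:worker-1:terminate"))  # TRUE; EXISTS returns 0/1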
2 changes: 1 addition & 1 deletion R/worker_loops.R
@@ -25,7 +25,7 @@ worker_loop_default = function(fun, constants = NULL, rush) {
rush$push_results(task$key, yss = list(ys))
}, error = function(e) {
condition = list(message = e$message)
rush$push_results(task$key, conditions = list(condition), state = "failed")
rush$push_failed(task$key, conditions = list(condition))
})
} else {
if (rush$terminated_on_idle) break
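
The change above routes evaluation errors to the new `push_failed()` instead of overloading `push_results()`. The surrounding loop, reduced to its error-handling shape; how `fun` is invoked is an assumption, as worker_loop_default() may pass constants differently:

tryCatch({
  ys = do.call(fun, task$xs)  # assumed invocation of the user payload
  rush$push_results(task$key, yss = list(ys))
}, error = function(e) {
  rush$push_failed(task$key, conditions = list(list(message = e$message)))
})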